1
- import json
2
- import time
3
- from decimal import Decimal
4
1
from typing import Optional
5
- from urllib .parse import urljoin
6
-
7
- import numpy as np
8
- import requests
9
2
10
3
from core .entities .embedding_type import EmbeddingInputType
11
- from core .model_runtime .entities .common_entities import I18nObject
12
- from core .model_runtime .entities .model_entities import (
13
- AIModelEntity ,
14
- FetchFrom ,
15
- ModelPropertyKey ,
16
- ModelType ,
17
- PriceConfig ,
18
- PriceType ,
4
+ from core .model_runtime .entities .text_embedding_entities import TextEmbeddingResult
5
+ from core .model_runtime .model_providers .openai_api_compatible .text_embedding .text_embedding import (
6
+ OAICompatEmbeddingModel ,
19
7
)
20
- from core .model_runtime .entities .text_embedding_entities import EmbeddingUsage , TextEmbeddingResult
21
- from core .model_runtime .errors .validate import CredentialsValidateFailedError
22
- from core .model_runtime .model_providers .__base .text_embedding_model import TextEmbeddingModel
23
- from core .model_runtime .model_providers .openai_api_compatible ._common import _CommonOaiApiCompat
24
8
25
9
26
- class OAICompatEmbeddingModel ( _CommonOaiApiCompat , TextEmbeddingModel ):
10
+ class PerfXCloudEmbeddingModel ( OAICompatEmbeddingModel ):
27
11
"""
28
12
Model class for an OpenAI API-compatible text embedding model.
29
13
"""
@@ -47,86 +31,10 @@ def _invoke(
47
31
:return: embeddings result
48
32
"""
49
33
50
- # Prepare headers and payload for the request
51
- headers = {"Content-Type" : "application/json" }
52
-
53
- api_key = credentials .get ("api_key" )
54
- if api_key :
55
- headers ["Authorization" ] = f"Bearer { api_key } "
56
- endpoint_url : Optional [str ]
57
34
if "endpoint_url" not in credentials or credentials ["endpoint_url" ] == "" :
58
- endpoint_url = "https://cloud.perfxlab.cn/v1/"
59
- else :
60
- endpoint_url = credentials .get ("endpoint_url" )
61
- assert endpoint_url is not None , "endpoint_url is required in credentials"
62
- if not endpoint_url .endswith ("/" ):
63
- endpoint_url += "/"
64
-
65
- assert isinstance (endpoint_url , str )
66
- endpoint_url = urljoin (endpoint_url , "embeddings" )
67
-
68
- extra_model_kwargs = {}
69
- if user :
70
- extra_model_kwargs ["user" ] = user
71
-
72
- extra_model_kwargs ["encoding_format" ] = "float"
73
-
74
- # get model properties
75
- context_size = self ._get_context_size (model , credentials )
76
- max_chunks = self ._get_max_chunks (model , credentials )
77
-
78
- inputs = []
79
- indices = []
80
- used_tokens = 0
81
-
82
- for i , text in enumerate (texts ):
83
- # Here token count is only an approximation based on the GPT2 tokenizer
84
- # TODO: Optimize for better token estimation and chunking
85
- num_tokens = self ._get_num_tokens_by_gpt2 (text )
86
-
87
- if num_tokens >= context_size :
88
- cutoff = int (np .floor (len (text ) * (context_size / num_tokens )))
89
- # if num tokens is larger than context length, only use the start
90
- inputs .append (text [0 :cutoff ])
91
- else :
92
- inputs .append (text )
93
- indices += [i ]
94
-
95
- batched_embeddings = []
96
- _iter = range (0 , len (inputs ), max_chunks )
97
-
98
- for i in _iter :
99
- # Prepare the payload for the request
100
- payload = {"input" : inputs [i : i + max_chunks ], "model" : model , ** extra_model_kwargs }
101
-
102
- # Make the request to the OpenAI API
103
- response = requests .post (endpoint_url , headers = headers , data = json .dumps (payload ), timeout = (10 , 300 ))
35
+ credentials ["endpoint_url" ] = "https://cloud.perfxlab.cn/v1/"
104
36
105
- response .raise_for_status () # Raise an exception for HTTP errors
106
- response_data = response .json ()
107
-
108
- # Extract embeddings and used tokens from the response
109
- embeddings_batch = [data ["embedding" ] for data in response_data ["data" ]]
110
- embedding_used_tokens = response_data ["usage" ]["total_tokens" ]
111
-
112
- used_tokens += embedding_used_tokens
113
- batched_embeddings += embeddings_batch
114
-
115
- # calc usage
116
- usage = self ._calc_response_usage (model = model , credentials = credentials , tokens = used_tokens )
117
-
118
- return TextEmbeddingResult (embeddings = batched_embeddings , usage = usage , model = model )
119
-
120
- def get_num_tokens (self , model : str , credentials : dict , texts : list [str ]) -> int :
121
- """
122
- Approximate number of tokens for given messages using GPT2 tokenizer
123
-
124
- :param model: model name
125
- :param credentials: model credentials
126
- :param texts: texts to embed
127
- :return:
128
- """
129
- return sum (self ._get_num_tokens_by_gpt2 (text ) for text in texts )
37
+ return OAICompatEmbeddingModel ._invoke (self , model , credentials , texts , user , input_type )
130
38
131
39
def validate_credentials (self , model : str , credentials : dict ) -> None :
132
40
"""
@@ -136,93 +44,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
136
44
:param credentials: model credentials
137
45
:return:
138
46
"""
139
- try :
140
- headers = {"Content-Type" : "application/json" }
141
-
142
- api_key = credentials .get ("api_key" )
143
-
144
- if api_key :
145
- headers ["Authorization" ] = f"Bearer { api_key } "
146
-
147
- endpoint_url : Optional [str ]
148
- if "endpoint_url" not in credentials or credentials ["endpoint_url" ] == "" :
149
- endpoint_url = "https://cloud.perfxlab.cn/v1/"
150
- else :
151
- endpoint_url = credentials .get ("endpoint_url" )
152
- assert endpoint_url is not None , "endpoint_url is required in credentials"
153
- if not endpoint_url .endswith ("/" ):
154
- endpoint_url += "/"
155
-
156
- assert isinstance (endpoint_url , str )
157
- endpoint_url = urljoin (endpoint_url , "embeddings" )
158
-
159
- payload = {"input" : "ping" , "model" : model }
160
-
161
- response = requests .post (url = endpoint_url , headers = headers , data = json .dumps (payload ), timeout = (10 , 300 ))
162
-
163
- if response .status_code != 200 :
164
- raise CredentialsValidateFailedError (
165
- f"Credentials validation failed with status code { response .status_code } "
166
- )
167
-
168
- try :
169
- json_result = response .json ()
170
- except json .JSONDecodeError as e :
171
- raise CredentialsValidateFailedError ("Credentials validation failed: JSON decode error" )
172
-
173
- if "model" not in json_result :
174
- raise CredentialsValidateFailedError ("Credentials validation failed: invalid response" )
175
- except CredentialsValidateFailedError :
176
- raise
177
- except Exception as ex :
178
- raise CredentialsValidateFailedError (str (ex ))
179
-
180
- def get_customizable_model_schema (self , model : str , credentials : dict ) -> AIModelEntity :
181
- """
182
- generate custom model entities from credentials
183
- """
184
- entity = AIModelEntity (
185
- model = model ,
186
- label = I18nObject (en_US = model ),
187
- model_type = ModelType .TEXT_EMBEDDING ,
188
- fetch_from = FetchFrom .CUSTOMIZABLE_MODEL ,
189
- model_properties = {
190
- ModelPropertyKey .CONTEXT_SIZE : int (credentials .get ("context_size" , 512 )),
191
- ModelPropertyKey .MAX_CHUNKS : 1 ,
192
- },
193
- parameter_rules = [],
194
- pricing = PriceConfig (
195
- input = Decimal (credentials .get ("input_price" , 0 )),
196
- unit = Decimal (credentials .get ("unit" , 0 )),
197
- currency = credentials .get ("currency" , "USD" ),
198
- ),
199
- )
200
-
201
- return entity
202
-
203
- def _calc_response_usage (self , model : str , credentials : dict , tokens : int ) -> EmbeddingUsage :
204
- """
205
- Calculate response usage
206
-
207
- :param model: model name
208
- :param credentials: model credentials
209
- :param tokens: input tokens
210
- :return: usage
211
- """
212
- # get input price info
213
- input_price_info = self .get_price (
214
- model = model , credentials = credentials , price_type = PriceType .INPUT , tokens = tokens
215
- )
216
-
217
- # transform usage
218
- usage = EmbeddingUsage (
219
- tokens = tokens ,
220
- total_tokens = tokens ,
221
- unit_price = input_price_info .unit_price ,
222
- price_unit = input_price_info .unit ,
223
- total_price = input_price_info .total_amount ,
224
- currency = input_price_info .currency ,
225
- latency = time .perf_counter () - self .started_at ,
226
- )
47
+ if "endpoint_url" not in credentials or credentials ["endpoint_url" ] == "" :
48
+ credentials ["endpoint_url" ] = "https://cloud.perfxlab.cn/v1/"
227
49
228
- return usage
50
+ OAICompatEmbeddingModel . validate_credentials ( self , model , credentials )
0 commit comments