Skip to content

Commit d44882c

Browse files
authored
refactor: reduce duplciate code by inheritance (#13073)
1 parent 23c68ef commit d44882c

File tree

1 file changed

+9
-187
lines changed

1 file changed

+9
-187
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,13 @@
1-
import json
2-
import time
3-
from decimal import Decimal
41
from typing import Optional
5-
from urllib.parse import urljoin
6-
7-
import numpy as np
8-
import requests
92

103
from core.entities.embedding_type import EmbeddingInputType
11-
from core.model_runtime.entities.common_entities import I18nObject
12-
from core.model_runtime.entities.model_entities import (
13-
AIModelEntity,
14-
FetchFrom,
15-
ModelPropertyKey,
16-
ModelType,
17-
PriceConfig,
18-
PriceType,
4+
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
5+
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
6+
OAICompatEmbeddingModel,
197
)
20-
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
21-
from core.model_runtime.errors.validate import CredentialsValidateFailedError
22-
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
23-
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
248

259

26-
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
10+
class PerfXCloudEmbeddingModel(OAICompatEmbeddingModel):
2711
"""
2812
Model class for an OpenAI API-compatible text embedding model.
2913
"""
@@ -47,86 +31,10 @@ def _invoke(
4731
:return: embeddings result
4832
"""
4933

50-
# Prepare headers and payload for the request
51-
headers = {"Content-Type": "application/json"}
52-
53-
api_key = credentials.get("api_key")
54-
if api_key:
55-
headers["Authorization"] = f"Bearer {api_key}"
56-
endpoint_url: Optional[str]
5734
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
58-
endpoint_url = "https://cloud.perfxlab.cn/v1/"
59-
else:
60-
endpoint_url = credentials.get("endpoint_url")
61-
assert endpoint_url is not None, "endpoint_url is required in credentials"
62-
if not endpoint_url.endswith("/"):
63-
endpoint_url += "/"
64-
65-
assert isinstance(endpoint_url, str)
66-
endpoint_url = urljoin(endpoint_url, "embeddings")
67-
68-
extra_model_kwargs = {}
69-
if user:
70-
extra_model_kwargs["user"] = user
71-
72-
extra_model_kwargs["encoding_format"] = "float"
73-
74-
# get model properties
75-
context_size = self._get_context_size(model, credentials)
76-
max_chunks = self._get_max_chunks(model, credentials)
77-
78-
inputs = []
79-
indices = []
80-
used_tokens = 0
81-
82-
for i, text in enumerate(texts):
83-
# Here token count is only an approximation based on the GPT2 tokenizer
84-
# TODO: Optimize for better token estimation and chunking
85-
num_tokens = self._get_num_tokens_by_gpt2(text)
86-
87-
if num_tokens >= context_size:
88-
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
89-
# if num tokens is larger than context length, only use the start
90-
inputs.append(text[0:cutoff])
91-
else:
92-
inputs.append(text)
93-
indices += [i]
94-
95-
batched_embeddings = []
96-
_iter = range(0, len(inputs), max_chunks)
97-
98-
for i in _iter:
99-
# Prepare the payload for the request
100-
payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
101-
102-
# Make the request to the OpenAI API
103-
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
35+
credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
10436

105-
response.raise_for_status() # Raise an exception for HTTP errors
106-
response_data = response.json()
107-
108-
# Extract embeddings and used tokens from the response
109-
embeddings_batch = [data["embedding"] for data in response_data["data"]]
110-
embedding_used_tokens = response_data["usage"]["total_tokens"]
111-
112-
used_tokens += embedding_used_tokens
113-
batched_embeddings += embeddings_batch
114-
115-
# calc usage
116-
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
117-
118-
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
119-
120-
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
121-
"""
122-
Approximate number of tokens for given messages using GPT2 tokenizer
123-
124-
:param model: model name
125-
:param credentials: model credentials
126-
:param texts: texts to embed
127-
:return:
128-
"""
129-
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
37+
return OAICompatEmbeddingModel._invoke(self, model, credentials, texts, user, input_type)
13038

13139
def validate_credentials(self, model: str, credentials: dict) -> None:
13240
"""
@@ -136,93 +44,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
13644
:param credentials: model credentials
13745
:return:
13846
"""
139-
try:
140-
headers = {"Content-Type": "application/json"}
141-
142-
api_key = credentials.get("api_key")
143-
144-
if api_key:
145-
headers["Authorization"] = f"Bearer {api_key}"
146-
147-
endpoint_url: Optional[str]
148-
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
149-
endpoint_url = "https://cloud.perfxlab.cn/v1/"
150-
else:
151-
endpoint_url = credentials.get("endpoint_url")
152-
assert endpoint_url is not None, "endpoint_url is required in credentials"
153-
if not endpoint_url.endswith("/"):
154-
endpoint_url += "/"
155-
156-
assert isinstance(endpoint_url, str)
157-
endpoint_url = urljoin(endpoint_url, "embeddings")
158-
159-
payload = {"input": "ping", "model": model}
160-
161-
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
162-
163-
if response.status_code != 200:
164-
raise CredentialsValidateFailedError(
165-
f"Credentials validation failed with status code {response.status_code}"
166-
)
167-
168-
try:
169-
json_result = response.json()
170-
except json.JSONDecodeError as e:
171-
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
172-
173-
if "model" not in json_result:
174-
raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
175-
except CredentialsValidateFailedError:
176-
raise
177-
except Exception as ex:
178-
raise CredentialsValidateFailedError(str(ex))
179-
180-
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
181-
"""
182-
generate custom model entities from credentials
183-
"""
184-
entity = AIModelEntity(
185-
model=model,
186-
label=I18nObject(en_US=model),
187-
model_type=ModelType.TEXT_EMBEDDING,
188-
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
189-
model_properties={
190-
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
191-
ModelPropertyKey.MAX_CHUNKS: 1,
192-
},
193-
parameter_rules=[],
194-
pricing=PriceConfig(
195-
input=Decimal(credentials.get("input_price", 0)),
196-
unit=Decimal(credentials.get("unit", 0)),
197-
currency=credentials.get("currency", "USD"),
198-
),
199-
)
200-
201-
return entity
202-
203-
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
204-
"""
205-
Calculate response usage
206-
207-
:param model: model name
208-
:param credentials: model credentials
209-
:param tokens: input tokens
210-
:return: usage
211-
"""
212-
# get input price info
213-
input_price_info = self.get_price(
214-
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
215-
)
216-
217-
# transform usage
218-
usage = EmbeddingUsage(
219-
tokens=tokens,
220-
total_tokens=tokens,
221-
unit_price=input_price_info.unit_price,
222-
price_unit=input_price_info.unit,
223-
total_price=input_price_info.total_amount,
224-
currency=input_price_info.currency,
225-
latency=time.perf_counter() - self.started_at,
226-
)
47+
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
48+
credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
22749

228-
return usage
50+
OAICompatEmbeddingModel.validate_credentials(self, model, credentials)

0 commit comments

Comments
 (0)