@@ -4,12 +4,14 @@
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage

@@ -246,14 +248,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))

-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
     )
-    return model_response
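Review note: the diff extracts `_process_embedding_response` so the sync and async paths share one response-parsing routine, and `embedding` now routes to the new `async_embedding` coroutine when `aembedding=True`. A minimal usage sketch through litellm's public `embedding`/`aembedding` API follows; the Cohere model name and API key are placeholder assumptions for illustration, not values from this diff.

```python
import asyncio
import os

import litellm

# Placeholder credential; a real Cohere key is assumed in practice.
os.environ["COHERE_API_KEY"] = "my-cohere-key"

# Sync path: the handler now posts via a reusable HTTPHandler instead of requests.
sync_response = litellm.embedding(
    model="cohere/embed-english-v3.0",  # assumed model name for illustration
    input=["good morning from litellm"],
)
print(sync_response.usage)

# Async path: litellm.aembedding() reaches the Cohere handler with
# aembedding=True, which returns the async_embedding() coroutine added here.
async def main():
    async_response = await litellm.aembedding(
        model="cohere/embed-english-v3.0",
        input=["good morning from litellm"],
    )
    print(async_response.usage)

asyncio.run(main())
```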