7
7
import traceback
8
8
import types
9
9
from enum import Enum
10
- from typing import Callable , Optional
10
+ from typing import Any , Callable , Optional , Union
11
11
12
12
import httpx # type: ignore
13
13
import requests # type: ignore
14
14
15
15
import litellm
16
+ from litellm .litellm_core_utils .litellm_logging import Logging as LiteLLMLoggingObj
17
+ from litellm .llms .custom_httpx .http_handler import AsyncHTTPHandler , HTTPHandler
16
18
from litellm .utils import Choices , Message , ModelResponse , Usage
17
19
18
20
@@ -249,14 +251,98 @@ def completion(
249
251
return model_response
250
252
251
253
254
def _process_embedding_response(
    embeddings: list,
    model_response: litellm.EmbeddingResponse,
    model: str,
    encoding: Any,
    input: list,
) -> litellm.EmbeddingResponse:
    """Populate ``model_response`` from a raw list of Cohere embedding vectors.

    Shared by the sync and async embedding paths. Converts each vector into
    the OpenAI-style ``{"object", "index", "embedding"}`` dict, stamps the
    model name, and attaches a ``Usage`` object with locally counted tokens.

    Args:
        embeddings: Raw embedding vectors as returned by the Cohere API.
        model_response: The response object to fill in (mutated in place).
        model: Model name to record on the response.
        encoding: Tokenizer exposing ``encode(text)``; used for token counts.
        input: The original list of input texts that were embedded.

    Returns:
        The same ``model_response`` object, now populated.
    """
    model_response.object = "list"
    model_response.data = [
        {"object": "embedding", "index": position, "embedding": vector}
        for position, vector in enumerate(embeddings)
    ]
    model_response.model = model

    # The embed endpoint response parsed here carries no usage info,
    # so prompt tokens are counted locally from the original inputs.
    input_tokens = sum(len(encoding.encode(text)) for text in input)

    setattr(
        model_response,
        "usage",
        Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
        ),
    )

    return model_response
283
+
284
async def async_embedding(
    model: str,
    data: dict,
    input: list,
    model_response: litellm.utils.EmbeddingResponse,
    timeout: Union[float, httpx.Timeout],
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    api_base: str,
    api_key: Optional[str],
    headers: dict,
    encoding: Callable,
    client: Optional[AsyncHTTPHandler] = None,
):
    """Async path for Cohere embeddings.

    Logs the request, POSTs ``data`` to ``api_base``, logs the raw response,
    and converts it into ``model_response`` via ``_process_embedding_response``.

    Args:
        model: Model name to record on the response.
        data: Fully-built request payload, sent as the JSON body.
        input: Original list of input texts (used for logging/token counts).
        model_response: Response object to populate.
        timeout: Request timeout.  # NOTE(review): currently unused here — TODO confirm whether it should be forwarded to the HTTP client
        logging_obj: litellm logging helper; ``pre_call``/``post_call`` are invoked.
        optional_params: Extra provider params.  # NOTE(review): unused in this function; presumably already folded into ``data`` by the caller
        api_base: Cohere embed endpoint URL.
        api_key: API key (only used for logging; auth lives in ``headers``).
        headers: Fully-built request headers.
        encoding: Tokenizer used for local token counting.
        client: Optional pre-built async HTTP client; created if not supplied.

    Raises:
        CohereError: If the API responds with a non-200 status code.
    """
    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            "api_base": api_base,
        },
    )
    ## COMPLETION CALL
    if client is None:
        client = AsyncHTTPHandler(concurrent_limit=1)

    response = await client.post(api_base, headers=headers, data=json.dumps(data))

    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )

    # Mirror the sync path: surface API errors as CohereError instead of
    # failing later with a KeyError on the missing "embeddings" field.
    if response.status_code != 200:
        raise CohereError(message=response.text, status_code=response.status_code)

    embeddings = response.json()["embeddings"]

    ## PROCESS RESPONSE ##
    return _process_embedding_response(
        embeddings=embeddings,
        model_response=model_response,
        model=model,
        encoding=encoding,
        input=input,
    )
333
+
334
+
252
335
def embedding (
253
336
model : str ,
254
337
input : list ,
255
338
model_response : litellm .EmbeddingResponse ,
339
+ logging_obj : LiteLLMLoggingObj ,
340
+ optional_params : dict ,
341
+ encoding : Any ,
256
342
api_key : Optional [str ] = None ,
257
- logging_obj = None ,
258
- encoding = None ,
259
- optional_params = None ,
343
+ aembedding : Optional [ bool ] = None ,
344
+ timeout : Union [ float , httpx . Timeout ] = httpx . Timeout ( None ) ,
345
+ client : Optional [ Union [ HTTPHandler , AsyncHTTPHandler ]] = None ,
260
346
):
261
347
headers = validate_environment (api_key )
262
348
embed_url = "https://api.cohere.ai/v1/embed"
@@ -273,8 +359,26 @@ def embedding(
273
359
api_key = api_key ,
274
360
additional_args = {"complete_input_dict" : data },
275
361
)
362
+
363
+ ## ROUTING
364
+ if aembedding is True :
365
+ return async_embedding (
366
+ model = model ,
367
+ data = data ,
368
+ input = input ,
369
+ model_response = model_response ,
370
+ timeout = timeout ,
371
+ logging_obj = logging_obj ,
372
+ optional_params = optional_params ,
373
+ api_base = embed_url ,
374
+ api_key = api_key ,
375
+ headers = headers ,
376
+ encoding = encoding ,
377
+ )
276
378
## COMPLETION CALL
277
- response = requests .post (embed_url , headers = headers , data = json .dumps (data ))
379
+ if client is None or not isinstance (client , HTTPHandler ):
380
+ client = HTTPHandler (concurrent_limit = 1 )
381
+ response = client .post (embed_url , headers = headers , data = json .dumps (data ))
278
382
## LOGGING
279
383
logging_obj .post_call (
280
384
input = input ,
@@ -296,23 +400,11 @@ def embedding(
296
400
if response .status_code != 200 :
297
401
raise CohereError (message = response .text , status_code = response .status_code )
298
402
embeddings = response .json ()["embeddings" ]
299
- output_data = []
300
- for idx , embedding in enumerate (embeddings ):
301
- output_data .append (
302
- {"object" : "embedding" , "index" : idx , "embedding" : embedding }
303
- )
304
- model_response .object = "list"
305
- model_response .data = output_data
306
- model_response .model = model
307
- input_tokens = 0
308
- for text in input :
309
- input_tokens += len (encoding .encode (text ))
310
403
311
- setattr (
312
- model_response ,
313
- "usage" ,
314
- Usage (
315
- prompt_tokens = input_tokens , completion_tokens = 0 , total_tokens = input_tokens
316
- ) ,
404
+ return _process_embedding_response (
405
+ embeddings = embeddings ,
406
+ model_response = model_response ,
407
+ model = model ,
408
+ encoding = encoding ,
409
+ input = input ,
317
410
)
318
- return model_response
0 commit comments