Commit 5d17016

fix(nvidia_nim/embed.py): add 'dimensions' support (#8302)

* fix(nvidia_nim/embed.py): add 'dimensions' support (Fixes #8238)
* fix(proxy_server.py): initialize router redis cache if set up on proxy (Fixes #6602)
* test: add unit testing for new helper function

1 parent 16be203 · commit 5d17016

File tree: 5 files changed (+36, -2)

Diff for: litellm/llms/nvidia_nim/embed.py (+3, -1)

```diff
@@ -58,7 +58,7 @@ def get_config(cls):
     def get_supported_openai_params(
         self,
     ):
-        return ["encoding_format", "user"]
+        return ["encoding_format", "user", "dimensions"]

     def map_openai_params(
         self,
@@ -73,6 +73,8 @@ def map_openai_params(
                 optional_params["extra_body"].update({"input_type": v})
             elif k == "truncate":
                 optional_params["extra_body"].update({"truncate": v})
+            else:
+                optional_params[k] = v

         if kwargs is not None:
             # pass kwargs in extra_body
```
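With `dimensions` now in `get_supported_openai_params` and forwarded unchanged by the new `else` branch, callers can request a specific embedding width through the OpenAI-compatible interface. A minimal usage sketch, assuming NIM credentials are already configured in the environment; nothing below is part of the commit itself, and the model name simply mirrors the test further down:

```python
import litellm

# Sketch only: assumes a reachable NVIDIA NIM endpoint and API key
# configured via environment variables.
response = litellm.embedding(
    model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
    input="What is the meaning of life?",
    input_type="passage",  # NIM-specific param, forwarded via extra_body
    dimensions=1024,       # newly supported OpenAI param, forwarded as-is
)
print(len(response.data[0]["embedding"]))  # expect 1024 if the model honors it
```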

Diff for: litellm/proxy/proxy_server.py (+5, -1)

```diff
@@ -1631,7 +1631,7 @@ def _init_cache(
         self,
         cache_params: dict,
     ):
-        global redis_usage_cache
+        global redis_usage_cache, llm_router
         from litellm import Cache

         if "default_in_memory_ttl" in cache_params:
@@ -1646,6 +1646,10 @@ def _init_cache(
         ## INIT PROXY REDIS USAGE CLIENT ##
         redis_usage_cache = litellm.cache.cache

+        ## INIT ROUTER REDIS CACHE ##
+        if llm_router is not None:
+            llm_router._update_redis_cache(cache=redis_usage_cache)
+
     async def get_config(self, config_file_path: Optional[str] = None) -> dict:
         """
         Load config file
```
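In practice this means the router no longer needs its own cache block when the proxy already configures one: `_init_cache` now hands the proxy's redis client to `llm_router`. A minimal proxy config sketch along the lines of the docstring below; the model entry and the Redis env-var names are illustrative, not part of this commit:

```yaml
model_list:
  - model_name: gpt-4o                # illustrative deployment
    litellm_params:
      model: openai/gpt-4o

litellm_settings:
  cache: true                         # router now inherits the proxy's redis cache
  cache_params:
    type: redis
    host: os.environ/REDIS_HOST       # litellm's env-var reference syntax
    port: os.environ/REDIS_PORT
```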

Diff for: litellm/router.py (+14)

````diff
@@ -573,6 +573,20 @@ def __init__( # noqa: PLR0915
             litellm.amoderation, call_type="moderation"
         )

+    def _update_redis_cache(self, cache: RedisCache):
+        """
+        Update the redis cache for the router, if none set.
+
+        Allows proxy user to just do
+        ```yaml
+        litellm_settings:
+            cache: true
+        ```
+        and caching to just work.
+        """
+        if self.cache.redis_cache is None:
+            self.cache.redis_cache = cache
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
````
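A small sketch of the helper's guard behavior, using mocks in the same style as the test at the end of this commit: only an unset `redis_cache` is filled in, so a router that already has one keeps it.

```python
from unittest.mock import MagicMock

from litellm import Router

# The wildcard model entry mirrors the test below; any valid model_list works.
router = Router(
    model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}}]
)

first, second = MagicMock(), MagicMock()
router._update_redis_cache(cache=first)   # fills the unset redis cache
router._update_redis_cache(cache=second)  # no-op: a cache is already set

assert router.cache.redis_cache is first
```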

Diff for: tests/llm_translation/test_nvidia_nim.py (+2)

```diff
@@ -77,6 +77,7 @@ def test_embedding_nvidia_nim():
             model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
             input="What is the meaning of life?",
             input_type="passage",
+            dimensions=1024,
             client=client,
         )
     except Exception as e:
@@ -87,3 +88,4 @@ def test_embedding_nvidia_nim():
     assert request_body["input"] == "What is the meaning of life?"
     assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
     assert request_body["extra_body"]["input_type"] == "passage"
+    assert request_body["dimensions"] == 1024
```

Diff for: tests/local_testing/test_router_utils.py (+12)

```diff
@@ -384,3 +384,15 @@ def test_router_get_model_access_groups(potential_access_group, expected_result)
         model_access_group=potential_access_group
     )
     assert access_groups == expected_result
+
+
+def test_router_redis_cache():
+    router = Router(
+        model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}}]
+    )
+
+    redis_cache = MagicMock()
+
+    router._update_redis_cache(cache=redis_cache)
+
+    assert router.cache.redis_cache == redis_cache
```
