Skip to content

Commit 12ac414

Browse files
authored
(Feat) - Allow calling Nova models on /bedrock/invoke/ (#8397)
* add nova to BEDROCK_INVOKE_PROVIDERS_LITERAL * BedrockInvokeNovaRequest * nova + invoke config * add AmazonInvokeNovaConfig * AmazonInvokeNovaConfig * run transform_request for invoke/nova models * AmazonInvokeNovaConfig * rename invoke tests * fix linting error * TestBedrockInvokeNovaJson * TestBedrockInvokeNovaJson * add converse_chunk_parser * test_nova_invoke_remove_empty_system_messages * test_nova_invoke_streaming_chunk_parsing
1 parent 80b1300 commit 12ac414

File tree

7 files changed

+276
-31
lines changed

7 files changed

+276
-31
lines changed

Diff for: litellm/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def identify(event_details):
360360
"meta.llama3-2-90b-instruct-v1:0",
361361
]
362362
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
363-
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
363+
"cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21", "nova"
364364
]
365365
####### COMPLETION MODELS ###################
366366
open_ai_chat_completion_models: List = []
@@ -863,6 +863,9 @@ def add_known_models():
863863
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
864864
AmazonAI21Config,
865865
)
866+
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
867+
AmazonInvokeNovaConfig,
868+
)
866869
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
867870
AmazonAnthropicConfig,
868871
)

Diff for: litellm/llms/bedrock/chat/invoke_handler.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1342,14 +1342,19 @@ def _chunk_parser(self, chunk_data: dict) -> GChunk:
13421342
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
13431343
is_finished = True
13441344
finish_reason = "stop"
1345-
######## converse bedrock.anthropic mappings ###############
1345+
######## /bedrock/converse mappings ###############
13461346
elif (
13471347
"contentBlockIndex" in chunk_data
13481348
or "stopReason" in chunk_data
13491349
or "metrics" in chunk_data
13501350
or "trace" in chunk_data
13511351
):
13521352
return self.converse_chunk_parser(chunk_data=chunk_data)
1353+
######### /bedrock/invoke nova mappings ###############
1354+
elif "contentBlockDelta" in chunk_data:
1355+
# when using /bedrock/invoke/nova, the chunk_data is nested under "contentBlockDelta"
1356+
_chunk_data = chunk_data.get("contentBlockDelta", None)
1357+
return self.converse_chunk_parser(chunk_data=_chunk_data)
13531358
######## bedrock.mistral mappings ###############
13541359
elif "outputs" in chunk_data:
13551360
if (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Handles transforming requests for `bedrock/invoke/{nova} models`
3+
4+
Inherits from `AmazonConverseConfig`
5+
6+
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
7+
"""
8+
9+
from typing import List
10+
11+
import litellm
12+
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
13+
from litellm.types.llms.openai import AllMessageValues
14+
15+
16+
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`.

    Reuses `AmazonConverseConfig.transform_request` to build a Converse-style
    payload, then adapts it to the stricter schema the invoke endpoint
    accepts for Nova models.
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request body for a Nova model on `/bedrock/invoke/`.

        Starts from the Converse-style request, then:
        1. drops an empty `system` list (rejected by the invoke endpoint)
        2. filters out fields not declared on `BedrockInvokeNovaRequest`
        """
        _transformed_nova_request = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        _bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
            **_transformed_nova_request
        )
        self._remove_empty_system_messages(_bedrock_invoke_nova_request)
        return self._filter_allowed_fields(_bedrock_invoke_nova_request)

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Drop any keys not declared on the `BedrockInvokeNovaRequest` TypedDict.

        e.g. `additionalModelRequestFields` / `additionalModelResponseFieldPaths`
        are valid for Converse but not accepted by `/bedrock/invoke/` for Nova.
        """
        allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
        return {
            k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove an empty `system` list from the request.

        `/bedrock/invoke/` does not allow empty `system` messages.
        """
        _system_message = bedrock_invoke_nova_request.get("system", None)
        # only pop when the caller supplied an explicitly-empty list; a
        # missing or populated `system` entry is left untouched
        if isinstance(_system_message, list) and len(_system_message) == 0:
            bedrock_invoke_nova_request.pop("system", None)

Diff for: litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import httpx
99

1010
import litellm
11+
from litellm._logging import verbose_logger
1112
from litellm.litellm_core_utils.core_helpers import map_finish_reason
1213
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
1314
from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -166,7 +167,7 @@ def sign_request(
166167

167168
return dict(request.headers)
168169

169-
def transform_request( # noqa: PLR0915
170+
def transform_request(
170171
self,
171172
model: str,
172173
messages: List[AllMessageValues],
@@ -224,6 +225,14 @@ def transform_request( # noqa: PLR0915
224225
litellm_params=litellm_params,
225226
headers=headers,
226227
)
228+
elif provider == "nova":
229+
return litellm.AmazonInvokeNovaConfig().transform_request(
230+
model=model,
231+
messages=messages,
232+
optional_params=optional_params,
233+
litellm_params=litellm_params,
234+
headers=headers,
235+
)
227236
elif provider == "ai21":
228237
## LOAD CONFIG
229238
config = litellm.AmazonAI21Config.get_config()
@@ -297,6 +306,10 @@ def transform_response( # noqa: PLR0915
297306
raise BedrockError(
298307
message=raw_response.text, status_code=raw_response.status_code
299308
)
309+
verbose_logger.debug(
310+
"bedrock invoke response % s",
311+
json.dumps(completion_response, indent=4, default=str),
312+
)
300313
provider = self.get_bedrock_invoke_provider(model)
301314
outputText: Optional[str] = None
302315
try:
@@ -322,6 +335,18 @@ def transform_response( # noqa: PLR0915
322335
api_key=api_key,
323336
json_mode=json_mode,
324337
)
338+
elif provider == "nova":
339+
return litellm.AmazonInvokeNovaConfig().transform_response(
340+
model=model,
341+
raw_response=raw_response,
342+
model_response=model_response,
343+
logging_obj=logging_obj,
344+
request_data=request_data,
345+
messages=messages,
346+
optional_params=optional_params,
347+
litellm_params=litellm_params,
348+
encoding=encoding,
349+
)
325350
elif provider == "ai21":
326351
outputText = (
327352
completion_response.get("completions")[0].get("data").get("text")
@@ -503,6 +528,7 @@ def get_bedrock_invoke_provider(
503528
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
504529
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
505530
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
531+
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
506532
"""
507533
if model.startswith("invoke/"):
508534
model = model.replace("invoke/", "", 1)
@@ -515,6 +541,10 @@ def get_bedrock_invoke_provider(
515541
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
516542
if provider is not None:
517543
return provider
544+
545+
# check if provider == "nova"
546+
if "nova" in model:
547+
return "nova"
518548
return None
519549

520550
@staticmethod

Diff for: litellm/types/llms/bedrock.py

+12
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,18 @@ class RequestObject(CommonRequestObject, total=False):
184184
messages: Required[List[MessageBlock]]
185185

186186

187+
class BedrockInvokeNovaRequest(TypedDict, total=False):
    """
    Request object for sending `nova` requests to `/bedrock/invoke/`

    `total=False`: every field is optional; absent keys are simply omitted
    from the request body.
    """

    # Converse-style conversation turns
    messages: List[MessageBlock]
    # generation parameters (e.g. temperature)
    inferenceConfig: InferenceConfig
    # system prompt blocks; an empty list is not accepted by /bedrock/invoke/
    system: List[SystemContentBlock]
    toolConfig: ToolConfigBlock
    guardrailConfig: Optional[GuardrailConfigBlock]
197+
198+
187199
class GenericStreamingChunk(TypedDict):
188200
text: Required[str]
189201
tool_use: Optional[ChatCompletionToolCallChunk]

Diff for: tests/llm_translation/test_bedrock_invoke_claude_json.py

-28
This file was deleted.

Diff for: tests/llm_translation/test_bedrock_invoke_tests.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from base_llm_unit_tests import BaseLLMChatTest
2+
import pytest
3+
import sys
4+
import os
5+
6+
7+
sys.path.insert(
8+
0, os.path.abspath("../..")
9+
) # Adds the parent directory to the system path
10+
import litellm
11+
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
12+
13+
14+
class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
    """Run the shared JSON-mode chat tests against Claude via `/bedrock/invoke/`."""

    def get_base_completion_call_args(self) -> dict:
        litellm._turn_on_debug()
        return {
            "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        # this suite only exercises JSON-mode behavior for invoke models
        if "json" not in request.function.__name__.lower():
            pytest.skip(
                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
            )
31+
32+
33+
class TestBedrockInvokeNovaJson(BaseLLMChatTest):
    """Run the shared JSON-mode chat tests against Nova via `/bedrock/invoke/`."""

    def get_base_completion_call_args(self) -> dict:
        litellm._turn_on_debug()
        return {
            "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        # this suite only exercises JSON-mode behavior for invoke models
        if "json" not in request.function.__name__.lower():
            pytest.skip(
                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
            )
50+
51+
52+
def test_nova_invoke_remove_empty_system_messages():
    """Verify an empty `system` list is stripped from the request in place."""
    request_body = BedrockInvokeNovaRequest(
        messages=[{"content": [{"text": "Hello"}], "role": "user"}],
        system=[],
        inferenceConfig={"temperature": 0.7},
    )

    config = litellm.AmazonInvokeNovaConfig()
    config._remove_empty_system_messages(request_body)

    # the empty system entry is gone; everything else survives untouched
    assert "system" not in request_body
    for expected_key in ("messages", "inferenceConfig"):
        assert expected_key in request_body
65+
66+
67+
def test_nova_invoke_filter_allowed_fields():
    """
    Test that _filter_allowed_fields only keeps fields defined in BedrockInvokeNovaRequest.

    Nova Invoke does not allow `additionalModelRequestFields` and
    `additionalModelResponseFieldPaths` in the request body, so both must be
    dropped while the schema-declared keys survive.
    """
    raw_request = {
        "messages": [{"content": [{"text": "Hello"}], "role": "user"}],
        "system": [{"text": "System prompt"}],
        "inferenceConfig": {"temperature": 0.7},
        "additionalModelRequestFields": {"this": "should be removed"},
        "additionalModelResponseFieldPaths": ["this", "should", "be", "removed"],
    }

    filtered = litellm.AmazonInvokeNovaConfig()._filter_allowed_fields(
        BedrockInvokeNovaRequest(**raw_request)
    )

    for removed_key in (
        "additionalModelRequestFields",
        "additionalModelResponseFieldPaths",
    ):
        assert removed_key not in filtered
    for kept_key in ("messages", "system", "inferenceConfig"):
        assert kept_key in filtered
91+
92+
93+
def test_nova_invoke_streaming_chunk_parsing():
    """
    Test that the AWSEventStreamDecoder correctly handles Nova's /bedrock/invoke/ streaming format
    where content is nested under 'contentBlockDelta'.
    """
    from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder

    # Initialize the decoder with a Nova model
    decoder = AWSEventStreamDecoder(model="bedrock/invoke/us.amazon.nova-micro-v1:0")

    # Test case 1: Text content in contentBlockDelta
    nova_text_chunk = {
        "contentBlockDelta": {
            "delta": {"text": "Hello, how can I help?"},
            "contentBlockIndex": 0,
        }
    }
    result = decoder._chunk_parser(nova_text_chunk)
    assert result["text"] == "Hello, how can I help?"
    assert result["index"] == 0
    assert not result["is_finished"]
    assert result["tool_use"] is None

    # Test case 2: Tool use start in contentBlockDelta
    nova_tool_start_chunk = {
        "contentBlockDelta": {
            "start": {"toolUse": {"name": "get_weather", "toolUseId": "tool_1"}},
            "contentBlockIndex": 1,
        }
    }
    result = decoder._chunk_parser(nova_tool_start_chunk)
    assert result["text"] == ""
    assert result["index"] == 1
    assert result["tool_use"] is not None
    assert result["tool_use"]["type"] == "function"
    assert result["tool_use"]["function"]["name"] == "get_weather"
    assert result["tool_use"]["id"] == "tool_1"

    # Test case 3: Tool use arguments in contentBlockDelta
    nova_tool_args_chunk = {
        "contentBlockDelta": {
            "delta": {"toolUse": {"input": '{"location": "New York"}'}},
            "contentBlockIndex": 2,
        }
    }
    result = decoder._chunk_parser(nova_tool_args_chunk)
    assert result["text"] == ""
    assert result["index"] == 2
    assert result["tool_use"] is not None
    assert result["tool_use"]["function"]["arguments"] == '{"location": "New York"}'

    # Test case 4: Stop reason in contentBlockDelta maps to OpenAI-style
    # "tool_calls" finish reason
    nova_stop_chunk = {
        "contentBlockDelta": {
            "stopReason": "tool_use",
        }
    }
    result = decoder._chunk_parser(nova_stop_chunk)
    assert result["is_finished"] is True
    assert result["finish_reason"] == "tool_calls"

0 commit comments

Comments (0)