Skip to content

Commit 8809360

Browse files
authored
(Bug Fix) - Bedrock completions with aws_region_name (#8384)
* test_bedrock_completion_with_region_name * test_bedrock_base_model_helper * test_bedrock_base_model_helper * fix aws_bedrock_runtime_endpoint * test_dynamic_aws_params_propagation * test_dynamic_aws_params_propagation
1 parent 64ccf4c commit 8809360

File tree

4 files changed

+310
-11
lines changed

4 files changed

+310
-11
lines changed

litellm/llms/bedrock/base_aws_llm.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self) -> None:
5252
"aws_role_name",
5353
"aws_web_identity_token",
5454
"aws_sts_endpoint",
55+
"aws_bedrock_runtime_endpoint",
5556
]
5657

5758
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:

litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get_complete_url(
8787
optional_params=optional_params,
8888
)
8989
### SET RUNTIME ENDPOINT ###
90-
aws_bedrock_runtime_endpoint = optional_params.pop(
90+
aws_bedrock_runtime_endpoint = optional_params.get(
9191
"aws_bedrock_runtime_endpoint", None
9292
) # https://bedrock-runtime.{region_name}.amazonaws.com
9393
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
@@ -125,15 +125,15 @@ def sign_request(
125125

126126
## CREDENTIALS ##
127127
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
128-
extra_headers = optional_params.pop("extra_headers", None)
129-
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
130-
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
131-
aws_session_token = optional_params.pop("aws_session_token", None)
132-
aws_role_name = optional_params.pop("aws_role_name", None)
133-
aws_session_name = optional_params.pop("aws_session_name", None)
134-
aws_profile_name = optional_params.pop("aws_profile_name", None)
135-
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
136-
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
128+
extra_headers = optional_params.get("extra_headers", None)
129+
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
130+
aws_access_key_id = optional_params.get("aws_access_key_id", None)
131+
aws_session_token = optional_params.get("aws_session_token", None)
132+
aws_role_name = optional_params.get("aws_role_name", None)
133+
aws_session_name = optional_params.get("aws_session_name", None)
134+
aws_profile_name = optional_params.get("aws_profile_name", None)
135+
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
136+
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
137137
aws_region_name = self._get_aws_region_name(optional_params)
138138

139139
credentials: Credentials = self.get_credentials(
@@ -588,7 +588,7 @@ def _get_aws_region_name(self, optional_params: dict) -> str:
588588
"""
589589
Get the AWS region name from the environment variables
590590
"""
591-
aws_region_name = optional_params.pop("aws_region_name", None)
591+
aws_region_name = optional_params.get("aws_region_name", None)
592592
### SET REGION NAME ###
593593
if aws_region_name is None:
594594
# check env #

tests/llm_translation/test_bedrock_completion.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
load_dotenv()
1515
import io
1616
import os
17+
import json
1718

1819
sys.path.insert(
1920
0, os.path.abspath("../..")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
# tests/llm_translation/test_base_aws_llm.py
2+
import os
3+
import json
4+
import pytest
5+
from unittest.mock import patch
6+
from botocore.credentials import Credentials
7+
import sys
8+
9+
sys.path.insert(
10+
0, os.path.abspath("../..")
11+
) # Adds the parent directory to the system path
12+
13+
import litellm
14+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
15+
from unittest.mock import Mock
16+
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
17+
18+
import json
19+
import pytest
20+
from unittest.mock import patch, Mock
21+
22+
import litellm
23+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
24+
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
25+
26+
27+
def test_bedrock_completion_with_region_name():
28+
litellm._turn_on_debug()
29+
client = HTTPHandler()
30+
31+
with patch.object(client, "post") as mock_post:
32+
mock_response = Mock()
33+
# Construct a response similar to our other tests.
34+
mock_response.text = json.dumps(
35+
{
36+
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
37+
"text": "Hello! How's it going? I hope you're having a fantastic day!",
38+
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
39+
"chat_history": [
40+
{"role": "USER", "message": "Hello, world!"},
41+
{
42+
"role": "CHATBOT",
43+
"message": "Hello! How's it going? I hope you're having a fantastic day!",
44+
},
45+
],
46+
"finish_reason": "COMPLETE",
47+
}
48+
)
49+
mock_response.status_code = 200
50+
mock_response.headers = {"Content-Type": "application/json"}
51+
mock_response.json = lambda: json.loads(mock_response.text)
52+
mock_post.return_value = mock_response
53+
54+
# Pass the client so that the HTTP call will be intercepted.
55+
response = litellm.completion(
56+
model="cohere.command-r-v1:0",
57+
messages=[{"role": "user", "content": "Hello, world!"}],
58+
aws_region_name="us-west-12",
59+
client=client,
60+
)
61+
62+
# Ensure our post method has been called.
63+
mock_post.assert_called_once()
64+
65+
assert (
66+
mock_post.call_args.kwargs["url"]
67+
== "https://bedrock-runtime.us-west-12.amazonaws.com/model/cohere.command-r-v1:0/invoke"
68+
)
69+
assert (
70+
mock_post.call_args.kwargs["data"]
71+
== '{"message": "Hello, world!", "chat_history": []}'
72+
)
73+
74+
# Print the URL and body of the HTTP request.
75+
# assert request was signed with the correct region
76+
_authorization_header = mock_post.call_args.kwargs["headers"]["Authorization"]
77+
import re
78+
79+
# Ensure the authorization header contains the exact region segment "us-west-12/bedrock/aws4_request"
80+
pattern = r"us-west-12/bedrock/aws4_request"
81+
assert re.search(pattern, _authorization_header) is not None
82+
83+
84+
def test_bedrock_completion_with_dynamic_authentication_params():
85+
litellm._turn_on_debug()
86+
client = HTTPHandler()
87+
88+
with patch.object(client, "post") as mock_post:
89+
mock_response = Mock()
90+
# Construct a response similar to our other tests.
91+
mock_response.text = json.dumps(
92+
{
93+
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
94+
"text": "Hello! How's it going? I hope you're having a fantastic day!",
95+
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
96+
"chat_history": [
97+
{"role": "USER", "message": "Hello, world!"},
98+
{
99+
"role": "CHATBOT",
100+
"message": "Hello! How's it going? I hope you're having a fantastic day!",
101+
},
102+
],
103+
"finish_reason": "COMPLETE",
104+
}
105+
)
106+
mock_response.status_code = 200
107+
mock_response.headers = {"Content-Type": "application/json"}
108+
mock_response.json = lambda: json.loads(mock_response.text)
109+
mock_post.return_value = mock_response
110+
111+
# Pass the client so that the HTTP call will be intercepted.
112+
response = litellm.completion(
113+
model="cohere.command-r-v1:0",
114+
messages=[{"role": "user", "content": "Hello, world!"}],
115+
aws_access_key_id="dynamically_generated_access_key_id",
116+
aws_secret_access_key="dynamically_generated_secret_access_key",
117+
client=client,
118+
)
119+
120+
# Ensure our post method has been called.
121+
mock_post.assert_called_once()
122+
import re
123+
124+
# Get authorization header
125+
_authorization_header = mock_post.call_args.kwargs["headers"]["Authorization"]
126+
127+
# Check for exact credential pattern
128+
pattern = r"AWS4-HMAC-SHA256 Credential=dynamically_generated_access_key_id/\d{8}/[a-z0-9-]+/bedrock/aws4_request"
129+
assert re.search(pattern, _authorization_header) is not None
130+
131+
132+
def test_bedrock_completion_with_dynamic_bedrock_runtime_endpoint():
133+
litellm._turn_on_debug()
134+
client = HTTPHandler()
135+
136+
with patch.object(client, "post") as mock_post:
137+
mock_response = Mock()
138+
# Construct a response similar to our other tests.
139+
mock_response.text = json.dumps(
140+
{
141+
"response_id": "379ed018/60744aff-e741-4aad-bd10-74639a4ade79",
142+
"text": "Hello! How's it going? I hope you're having a fantastic day!",
143+
"generation_id": "38709bb9-f20f-42d9-9c61-13a73b7bbc12",
144+
"chat_history": [
145+
{"role": "USER", "message": "Hello, world!"},
146+
{
147+
"role": "CHATBOT",
148+
"message": "Hello! How's it going? I hope you're having a fantastic day!",
149+
},
150+
],
151+
"finish_reason": "COMPLETE",
152+
}
153+
)
154+
mock_response.status_code = 200
155+
mock_response.headers = {"Content-Type": "application/json"}
156+
mock_response.json = lambda: json.loads(mock_response.text)
157+
mock_post.return_value = mock_response
158+
159+
# Pass the client so that the HTTP call will be intercepted.
160+
response = litellm.completion(
161+
model="cohere.command-r-v1:0",
162+
messages=[{"role": "user", "content": "Hello, world!"}],
163+
aws_bedrock_runtime_endpoint="https://my-fake-endpoint.com",
164+
client=client,
165+
)
166+
167+
# Ensure our post method has been called.
168+
mock_post.assert_called_once()
169+
assert (
170+
mock_post.call_args.kwargs["url"]
171+
== "https://my-fake-endpoint.com/model/cohere.command-r-v1:0/invoke"
172+
)
173+
174+
175+
# ------------------------------------------------------------------------------
176+
# A dummy credentials object to return from get_credentials.
177+
# (It must have attributes so that SigV4Auth.add_auth doesn't break.)
178+
# ------------------------------------------------------------------------------
179+
class DummyCredentials:
180+
access_key = "dummy_access"
181+
secret_key = "dummy_secret"
182+
token = "dummy_token"
183+
184+
185+
# ------------------------------------------------------------------------------
186+
# This test makes sure that a given dynamic parameter is passed into the call
187+
# to BaseAWSLLM.get_credentials. (Some dynamic params—for example aws_region_name
188+
# or aws_bedrock_runtime_endpoint—are already covered by other tests.)
189+
# ------------------------------------------------------------------------------
190+
@pytest.mark.parametrize(
191+
"model",
192+
[
193+
"bedrock/converse/cohere.command-r-v1:0",
194+
"cohere.command-r-v1:0",
195+
"bedrock/cohere.command-r-v1:0",
196+
"bedrock/invoke/cohere.command-r-v1:0",
197+
],
198+
)
199+
@pytest.mark.parametrize(
200+
"param_name, param_value",
201+
[
202+
("aws_session_token", "dummy_session_token"),
203+
("aws_session_name", "dummy_session_name"),
204+
("aws_profile_name", "dummy_profile_name"),
205+
("aws_role_name", "dummy_role_name"),
206+
("aws_web_identity_token", "dummy_web_identity_token"),
207+
("aws_sts_endpoint", "dummy_sts_endpoint"),
208+
],
209+
)
210+
def test_dynamic_aws_params_propagation(model, param_name, param_value):
211+
"""
212+
When passed to litellm.completion, each dynamic AWS authentication parameter
213+
should propagate down to the get_credentials() call in BaseAWSLLM.
214+
215+
Also tests different model parameter values.
216+
"""
217+
client = HTTPHandler()
218+
219+
# Base parameters required for the completion call.
220+
# (We include aws_access_key_id and aws_secret_access_key so that the correct auth
221+
# branch in get_credentials() is reached.)
222+
base_params = {
223+
"model": model,
224+
"messages": [{"role": "user", "content": "Hello, world!"}],
225+
"aws_access_key_id": "dummy_access",
226+
"aws_secret_access_key": "dummy_secret",
227+
"client": client,
228+
}
229+
# For parameters such as aws_role_name or aws_web_identity_token a session name is required.
230+
if param_name in ("aws_role_name", "aws_web_identity_token"):
231+
base_params["aws_session_name"] = "dummy_session_name"
232+
if param_name == "aws_web_identity_token":
233+
# The web identity branch also requires a role name.
234+
base_params["aws_role_name"] = "dummy_role_name"
235+
# Inject the dynamic parameter under test.
236+
base_params[param_name] = param_value
237+
238+
# Patch SigV4Auth in the signing (so that no actual signing is done).
239+
with patch("botocore.auth.SigV4Auth", autospec=True) as mock_sigv4:
240+
instance = mock_sigv4.return_value
241+
instance.add_auth.return_value = None
242+
243+
# Patch BaseAWSLLM.get_credentials so that we can capture its kwargs.
244+
def dummy_get_credentials(**kwargs):
245+
dummy_get_credentials.called_kwargs = kwargs # type: ignore[attr-defined]
246+
return DummyCredentials()
247+
248+
with patch.object(
249+
BaseAWSLLM, "get_credentials", side_effect=dummy_get_credentials
250+
):
251+
# Patch the HTTP client's post method to avoid an actual HTTP call.
252+
with patch.object(client, "post") as mock_post:
253+
mock_response = Mock()
254+
mock_response.text = json.dumps(
255+
{
256+
"response_id": "dummy_response",
257+
"text": "Hello! world",
258+
"generation_id": "dummy_gen",
259+
"chat_history": [],
260+
"finish_reason": "COMPLETE",
261+
}
262+
)
263+
if "converse" in model:
264+
mock_response.text = json.dumps(
265+
{
266+
"output": {
267+
"message": {
268+
"role": "assistant",
269+
"content": [{"text": "Here's a joke..."}],
270+
}
271+
},
272+
"usage": {
273+
"inputTokens": 12,
274+
"outputTokens": 6,
275+
"totalTokens": 18,
276+
},
277+
"stopReason": "stop",
278+
}
279+
)
280+
281+
mock_response.status_code = 200
282+
mock_response.headers = {"Content-Type": "application/json"}
283+
mock_response.json = lambda: json.loads(mock_response.text)
284+
mock_post.return_value = mock_response
285+
286+
# Call litellm.completion with our base & dynamic parameters.
287+
litellm.completion(**base_params)
288+
289+
print(
290+
"get_credentials.called_kwargs",
291+
json.dumps(dummy_get_credentials.called_kwargs, indent=4),
292+
)
293+
294+
# We now assert that get_credentials() was called with the dynamic param.
295+
assert (
296+
dummy_get_credentials.called_kwargs.get(param_name) == param_value
297+
)

0 commit comments

Comments
 (0)