Skip to content

Commit 28e5812

Browse files
committed
pass bedrock_invoke_provider to make_call
1 parent c7c4107 commit 28e5812

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

litellm/llms/bedrock/chat/invoke_handler.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
     parse_xml_params,
     prompt_factory,
 )
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -177,6 +180,7 @@ async def make_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -214,6 +218,14 @@ async def make_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=False,
+            )
+            completion_stream = decoder.aiter_bytes(
+                response.aiter_bytes(chunk_size=1024)
+            )
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.aiter_bytes(
@@ -248,6 +260,7 @@ def make_sync_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -283,6 +296,12 @@ def make_sync_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=True,
+            )
+            completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -1323,7 +1342,7 @@ def _chunk_parser(self, chunk_data: dict) -> GChunk:
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
-        ######## bedrock.anthropic mappings ###############
+        ######## converse bedrock.anthropic mappings ###############
         elif (
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data
@@ -1429,6 +1448,22 @@ def _parse_message_from_event(self, event) -> Optional[str]:
         return chunk.decode()  # type: ignore[no-any-return]


+class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
+    def __init__(
+        self,
+        model: str,
+        sync_stream: bool,
+    ) -> None:
+        super().__init__(model=model)
+        self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=sync_stream,
+        )
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
+
+
 class MockResponseIterator:  # for returning ai21 streaming responses
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response

litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py

-16
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,3 @@ def transform_response(
         api_key=api_key,
         json_mode=json_mode,
     )
-
-
-class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
-    def __init__(
-        self,
-        model: str,
-        sync_stream: bool,
-    ) -> None:
-        super().__init__(model=model)
-        self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
-            streaming_response=None,
-            sync_stream=sync_stream,
-        )
-
-    def _chunk_parser(self, chunk_data: dict) -> GChunk:
-        return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)

litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py

+2
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def get_async_custom_stream_wrapper(
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -473,6 +474,7 @@ def get_sync_custom_stream_wrapper(
                 messages=messages,
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
+                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
             ),
             model=model,
             custom_llm_provider="bedrock",

0 commit comments

Comments
 (0)