Skip to content

Commit 2b55c68

Browse files
chore(internal): improve bedrock streaming setup (anthropics#354)
1 parent 2778a22 commit 2b55c68

File tree

5 files changed

+47
-85
lines changed

5 files changed

+47
-85
lines changed

src/anthropic/_base_client.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@
7979
RAW_RESPONSE_HEADER,
8080
OVERRIDE_CAST_TO_HEADER,
8181
)
82-
from ._streaming import Stream, AsyncStream
82+
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
8383
from ._exceptions import (
8484
APIStatusError,
8585
APITimeoutError,
@@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL:
431431

432432
return merge_url
433433

434+
def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
435+
return SSEDecoder()
436+
434437
def _build_request(
435438
self,
436439
options: FinalRequestOptions,

src/anthropic/_streaming.py

+28-6
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import inspect
66
from types import TracebackType
77
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
8-
from typing_extensions import Self, TypeGuard, override, get_origin
8+
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
99

1010
import httpx
1111

@@ -23,6 +23,8 @@ class Stream(Generic[_T]):
2323

2424
response: httpx.Response
2525

26+
_decoder: SSEDecoder | SSEBytesDecoder
27+
2628
def __init__(
2729
self,
2830
*,
@@ -33,7 +35,7 @@ def __init__(
3335
self.response = response
3436
self._cast_to = cast_to
3537
self._client = client
36-
self._decoder = SSEDecoder()
38+
self._decoder = client._make_sse_decoder()
3739
self._iterator = self.__stream__()
3840

3941
def __next__(self) -> _T:
@@ -44,7 +46,10 @@ def __iter__(self) -> Iterator[_T]:
4446
yield item
4547

4648
def _iter_events(self) -> Iterator[ServerSentEvent]:
47-
yield from self._decoder.iter(self.response.iter_lines())
49+
if isinstance(self._decoder, SSEBytesDecoder):
50+
yield from self._decoder.iter_bytes(self.response.iter_bytes())
51+
else:
52+
yield from self._decoder.iter(self.response.iter_lines())
4853

4954
def __stream__(self) -> Iterator[_T]:
5055
cast_to = cast(Any, self._cast_to)
@@ -117,6 +122,8 @@ class AsyncStream(Generic[_T]):
117122

118123
response: httpx.Response
119124

125+
_decoder: SSEDecoder | SSEBytesDecoder
126+
120127
def __init__(
121128
self,
122129
*,
@@ -127,7 +134,7 @@ def __init__(
127134
self.response = response
128135
self._cast_to = cast_to
129136
self._client = client
130-
self._decoder = SSEDecoder()
137+
self._decoder = client._make_sse_decoder()
131138
self._iterator = self.__stream__()
132139

133140
async def __anext__(self) -> _T:
@@ -138,8 +145,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
138145
yield item
139146

140147
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
141-
async for sse in self._decoder.aiter(self.response.aiter_lines()):
142-
yield sse
148+
if isinstance(self._decoder, SSEBytesDecoder):
149+
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
150+
yield sse
151+
else:
152+
async for sse in self._decoder.aiter(self.response.aiter_lines()):
153+
yield sse
143154

144155
async def __stream__(self) -> AsyncIterator[_T]:
145156
cast_to = cast(Any, self._cast_to)
@@ -325,6 +336,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
325336
return None
326337

327338

339+
@runtime_checkable
340+
class SSEBytesDecoder(Protocol):
341+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
342+
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
343+
...
344+
345+
def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
346+
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
347+
...
348+
349+
328350
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
329351
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
330352
origin = get_origin(typ) or typ

src/anthropic/lib/bedrock/_client.py

+11-58
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,18 @@
22

33
import os
44
from typing import Any, Union, Mapping, TypeVar
5-
from typing_extensions import override, get_origin
5+
from typing_extensions import override
66

77
import httpx
88

99
from ... import _exceptions
10-
from ._stream import BedrockStream, AsyncBedrockStream
11-
from ..._types import NOT_GIVEN, NotGiven, ResponseT
10+
from ..._types import NOT_GIVEN, NotGiven
1211
from ..._utils import is_dict
1312
from ..._version import __version__
14-
from ..._response import extract_stream_chunk_type
1513
from ..._streaming import Stream, AsyncStream
1614
from ..._exceptions import APIStatusError
1715
from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient, FinalRequestOptions
16+
from ._stream_decoder import AWSEventStreamDecoder
1817
from ...resources.completions import Completions, AsyncCompletions
1918

2019
DEFAULT_VERSION = "bedrock-2023-05-31"
@@ -131,10 +130,12 @@ def __init__(
131130
_strict_response_validation=_strict_response_validation,
132131
)
133132

134-
self._default_stream_cls = BedrockStream
135-
136133
self.completions = Completions(self)
137134

135+
@override
136+
def _make_sse_decoder(self) -> AWSEventStreamDecoder:
137+
return AWSEventStreamDecoder()
138+
138139
@override
139140
def _prepare_request(self, request: httpx.Request) -> None:
140141
from ._auth import get_auth_headers
@@ -153,31 +154,6 @@ def _prepare_request(self, request: httpx.Request) -> None:
153154
)
154155
request.headers.update(headers)
155156

156-
@override
157-
def _process_response(
158-
self,
159-
*,
160-
cast_to: type[ResponseT],
161-
options: FinalRequestOptions,
162-
response: httpx.Response,
163-
stream: bool,
164-
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
165-
) -> ResponseT:
166-
if stream_cls is not None and get_origin(stream_cls) == Stream:
167-
chunk_type = extract_stream_chunk_type(stream_cls)
168-
169-
# the type: ignore is required as mypy doesn't like us
170-
# dynamically created a concrete type like this
171-
stream_cls = BedrockStream[chunk_type] # type: ignore
172-
173-
return super()._process_response(
174-
cast_to=cast_to,
175-
options=options,
176-
response=response,
177-
stream=stream,
178-
stream_cls=stream_cls,
179-
)
180-
181157

182158
class AsyncAnthropicBedrock(BaseBedrockClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient):
183159
completions: AsyncCompletions
@@ -231,10 +207,12 @@ def __init__(
231207
_strict_response_validation=_strict_response_validation,
232208
)
233209

234-
self._default_stream_cls = AsyncBedrockStream
235-
236210
self.completions = AsyncCompletions(self)
237211

212+
@override
213+
def _make_sse_decoder(self) -> AWSEventStreamDecoder:
214+
return AWSEventStreamDecoder()
215+
238216
@override
239217
async def _prepare_request(self, request: httpx.Request) -> None:
240218
from ._auth import get_auth_headers
@@ -252,28 +230,3 @@ async def _prepare_request(self, request: httpx.Request) -> None:
252230
data=data,
253231
)
254232
request.headers.update(headers)
255-
256-
@override
257-
async def _process_response(
258-
self,
259-
*,
260-
cast_to: type[ResponseT],
261-
options: FinalRequestOptions,
262-
response: httpx.Response,
263-
stream: bool,
264-
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
265-
) -> ResponseT:
266-
if stream_cls is not None and get_origin(stream_cls) == AsyncStream:
267-
chunk_type = extract_stream_chunk_type(stream_cls)
268-
269-
# the type: ignore is required as mypy doesn't like us
270-
# dynamically created a concrete type like this
271-
stream_cls = AsyncBedrockStream[chunk_type] # type: ignore
272-
273-
return await super()._process_response(
274-
cast_to=cast_to,
275-
options=options,
276-
response=response,
277-
stream=stream,
278-
stream_cls=stream_cls,
279-
)

src/anthropic/lib/bedrock/_stream.py

+2-18
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,17 @@
11
from __future__ import annotations
22

3-
from typing import TypeVar, Iterator
4-
from typing_extensions import AsyncIterator, override
3+
from typing import TypeVar
54

65
import httpx
76

87
from ..._client import Anthropic, AsyncAnthropic
9-
from ..._streaming import Stream, AsyncStream, ServerSentEvent
8+
from ..._streaming import Stream, AsyncStream
109
from ._stream_decoder import AWSEventStreamDecoder
1110

1211
_T = TypeVar("_T")
1312

1413

1514
class BedrockStream(Stream[_T]):
16-
# the AWS decoder expects `bytes` instead of `str`
17-
_decoder: AWSEventStreamDecoder # type: ignore
18-
1915
def __init__(
2016
self,
2117
*,
@@ -27,15 +23,8 @@ def __init__(
2723

2824
self._decoder = AWSEventStreamDecoder()
2925

30-
@override
31-
def _iter_events(self) -> Iterator[ServerSentEvent]:
32-
yield from self._decoder.iter(self.response.iter_bytes())
33-
3426

3527
class AsyncBedrockStream(AsyncStream[_T]):
36-
# the AWS decoder expects `bytes` instead of `str`
37-
_decoder: AWSEventStreamDecoder # type: ignore
38-
3928
def __init__(
4029
self,
4130
*,
@@ -46,8 +35,3 @@ def __init__(
4635
super().__init__(cast_to=cast_to, response=response, client=client)
4736

4837
self._decoder = AWSEventStreamDecoder()
49-
50-
@override
51-
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
52-
async for sse in self._decoder.aiter(self.response.aiter_bytes()):
53-
yield sse

src/anthropic/lib/bedrock/_stream_decoder.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def __init__(self) -> None:
2727

2828
self.parser = EventStreamJSONParser()
2929

30-
def iter(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
30+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
3131
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
3232
from botocore.eventstream import EventStreamBuffer
3333

@@ -39,7 +39,7 @@ def iter(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
3939
if message:
4040
yield ServerSentEvent(data=message, event="completion")
4141

42-
async def aiter(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
42+
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
4343
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
4444
from botocore.eventstream import EventStreamBuffer
4545

0 commit comments

Comments
 (0)