Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/anthropic/lib/bedrock/_stream_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
sse = self._parse_message_from_event(event)
if sse is not None:
yield sse

async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
Expand All @@ -47,18 +47,21 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
sse = self._parse_message_from_event(event)
if sse is not None:
yield sse

def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
def _parse_message_from_event(self, event: EventStreamMessage) -> ServerSentEvent | None:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
error_body = response_dict.get("body", b"")
if isinstance(error_body, bytes):
error_body = error_body.decode("utf-8", errors="replace")
return ServerSentEvent(data=error_body, event="error")

chunk = parsed_response.get("chunk")
if not chunk:
return None

return chunk.get("bytes").decode() # type: ignore[no-any-return]
return ServerSentEvent(data=chunk.get("bytes").decode(), event="completion") # type: ignore[arg-type]
55 changes: 55 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,58 @@ def test_region_infer_from_specified_profile(
client = AnthropicBedrock()

assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]


class TestAWSEventStreamDecoder:
def _make_event(self, response_dict: dict) -> t.Any:
"""Create a mock EventStreamMessage that returns the given response_dict."""

class MockEvent:
def __init__(self, resp: dict) -> None:
self._resp = resp

def to_response_dict(self) -> dict:
return self._resp

return MockEvent(response_dict)

def test_non_200_status_emits_error_sse(self) -> None:
from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder

decoder = AWSEventStreamDecoder()

error_body = b'{"message":"The system encountered an unexpected error during processing. Try your request again."}'
event = self._make_event(
{
"status_code": 400,
"headers": {
":exception-type": "internalServerException",
":content-type": "application/json",
":message-type": "exception",
},
"body": error_body,
}
)

sse = decoder._parse_message_from_event(event)
assert sse is not None
assert sse.event == "error"
assert sse.data == error_body.decode("utf-8")

def test_non_200_status_no_body(self) -> None:
from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder

decoder = AWSEventStreamDecoder()

event = self._make_event(
{
"status_code": 500,
"headers": {":message-type": "exception"},
"body": b"",
}
)

sse = decoder._parse_message_from_event(event)
assert sse is not None
assert sse.event == "error"
assert sse.data == ""