Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/anthropic/lib/bedrock/_stream_decoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Iterator, AsyncIterator

from ..._utils import lru_cache
from ..._utils import is_dict, lru_cache
from ..._streaming import ServerSentEvent

if TYPE_CHECKING:
Expand Down Expand Up @@ -37,7 +38,7 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
yield ServerSentEvent(data=message, event=self._get_sse_event_type(message))

async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
Expand All @@ -49,7 +50,16 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
yield ServerSentEvent(data=message, event=self._get_sse_event_type(message))

def _get_sse_event_type(self, message: str) -> str:
try:
data = json.loads(message)
if is_dict(data) and data.get("type") == "error":
return "error"
except (ValueError, KeyError):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when it raises ValueError or KeyError? I think the only error the code above can raise it's json.decoder.JSONDecodeError during json.loads(message)

pass
return "completion"

def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
response_dict = event.to_response_dict()
Expand Down
64 changes: 64 additions & 0 deletions tests/lib/test_bedrock_stream_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import json

from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder


class TestGetSseEventType:
def test_error_event_returns_error(self) -> None:
decoder = _make_decoder()
message = json.dumps(
{
"type": "error",
"error": {"type": "rate_limit_error", "message": "Rate limited", "details": None},
"request_id": "req_123",
}
)
assert decoder._get_sse_event_type(message) == "error"

def test_message_start_event_returns_completion(self) -> None:
decoder = _make_decoder()
message = json.dumps(
{
"type": "message_start",
"message": {
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-opus-4-7",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 10, "output_tokens": 0},
},
}
)
assert decoder._get_sse_event_type(message) == "completion"

def test_content_block_delta_returns_completion(self) -> None:
decoder = _make_decoder()
message = json.dumps(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "Hello"},
}
)
assert decoder._get_sse_event_type(message) == "completion"

def test_invalid_json_returns_completion(self) -> None:
decoder = _make_decoder()
assert decoder._get_sse_event_type("not json") == "completion"

def test_non_dict_json_returns_completion(self) -> None:
decoder = _make_decoder()
assert decoder._get_sse_event_type('"just a string"') == "completion"

def test_dict_without_type_returns_completion(self) -> None:
decoder = _make_decoder()
assert decoder._get_sse_event_type('{"foo": "bar"}') == "completion"


def _make_decoder() -> AWSEventStreamDecoder:
return AWSEventStreamDecoder()