From 43eb846a5b54bf41179fc2cd089bfd734a290696 Mon Sep 17 00:00:00 2001 From: Junhyuk Lee Date: Thu, 30 Apr 2026 00:15:21 -0500 Subject: [PATCH] fix: route Bedrock SSE error events to the error handler (#1472) Co-Authored-By: Claude Opus 4.6 --- src/anthropic/lib/bedrock/_stream_decoder.py | 16 ++++- tests/lib/test_bedrock_stream_decoder.py | 64 ++++++++++++++++++++ 2 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 tests/lib/test_bedrock_stream_decoder.py diff --git a/src/anthropic/lib/bedrock/_stream_decoder.py b/src/anthropic/lib/bedrock/_stream_decoder.py index 02e81a3ca..74e0eb6c3 100644 --- a/src/anthropic/lib/bedrock/_stream_decoder.py +++ b/src/anthropic/lib/bedrock/_stream_decoder.py @@ -1,8 +1,9 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Iterator, AsyncIterator -from ..._utils import lru_cache +from ..._utils import is_dict, lru_cache from ..._streaming import ServerSentEvent if TYPE_CHECKING: @@ -37,7 +38,7 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: - yield ServerSentEvent(data=message, event="completion") + yield ServerSentEvent(data=message, event=self._get_sse_event_type(message)) async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: """Given an async iterator that yields lines, iterate over it & yield every event encountered""" @@ -49,7 +50,16 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: - yield ServerSentEvent(data=message, event="completion") + yield ServerSentEvent(data=message, event=self._get_sse_event_type(message)) + + def _get_sse_event_type(self, message: str) -> str: + try: + data = json.loads(message) + if is_dict(data) and data.get("type") == "error": + return "error" + except (ValueError, KeyError): + pass + return "completion" def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: response_dict = event.to_response_dict() diff --git a/tests/lib/test_bedrock_stream_decoder.py b/tests/lib/test_bedrock_stream_decoder.py new file mode 100644 index 000000000..b8dc3cd80 --- /dev/null +++ b/tests/lib/test_bedrock_stream_decoder.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import json + +from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder + + +class TestGetSseEventType: + def test_error_event_returns_error(self) -> None: + decoder = _make_decoder() + message = json.dumps( + { + "type": "error", + "error": {"type": "rate_limit_error", "message": "Rate limited", "details": None}, + "request_id": "req_123", + } + ) + assert decoder._get_sse_event_type(message) == "error" + + def test_message_start_event_returns_completion(self) -> None: + decoder = _make_decoder() + message = json.dumps( + { + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-opus-4-7", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 10, "output_tokens": 0}, + }, + } + ) + assert decoder._get_sse_event_type(message) == "completion" + + def test_content_block_delta_returns_completion(self) -> None: + decoder = _make_decoder() + message = json.dumps( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Hello"}, + } + ) + assert decoder._get_sse_event_type(message) == "completion" + + def test_invalid_json_returns_completion(self) -> None: + decoder = _make_decoder() + assert decoder._get_sse_event_type("not json") == "completion" + + def test_non_dict_json_returns_completion(self) -> None: + decoder = _make_decoder() + assert decoder._get_sse_event_type('"just a string"') == "completion" + + def test_dict_without_type_returns_completion(self) -> None: + decoder = _make_decoder() + assert decoder._get_sse_event_type('{"foo": "bar"}') == "completion" + + +def _make_decoder() -> AWSEventStreamDecoder: + return AWSEventStreamDecoder()