diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py
index 5a5a562a0..904aa99e7 100644
--- a/src/anthropic/lib/streaming/_beta_messages.py
+++ b/src/anthropic/lib/streaming/_beta_messages.py
@@ -14,6 +14,7 @@
 
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import consume_sync_iterator, consume_async_iterator
+from ..._compat import parse_obj
 from ..._models import build, construct_type, construct_type_unchecked
 from ._beta_types import (
     BetaCitationEvent,
@@ -28,10 +29,25 @@
 )
 from ..._streaming import Stream, AsyncStream
 from ...types.beta import BetaRawMessageStreamEvent
+from ...types.beta.beta_raw_message_start_event import BetaRawMessageStartEvent
+from ...types.beta.beta_raw_message_delta_event import BetaRawMessageDeltaEvent
+from ...types.beta.beta_raw_message_stop_event import BetaRawMessageStopEvent
+from ...types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent
+from ...types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent
+from ...types.beta.beta_raw_content_block_stop_event import BetaRawContentBlockStopEvent
 from ..._utils._utils import is_given
 from .._parse._response import ResponseFormatT, parse_text
 from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock
 
+_BETA_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
+    "message_start": BetaRawMessageStartEvent,
+    "message_delta": BetaRawMessageDeltaEvent,
+    "message_stop": BetaRawMessageStopEvent,
+    "content_block_start": BetaRawContentBlockStartEvent,
+    "content_block_delta": BetaRawContentBlockDeltaEvent,
+    "content_block_stop": BetaRawContentBlockStopEvent,
+}
+
 
 class BetaMessageStream(Generic[ResponseFormatT]):
     text_stream: Iterator[str]
@@ -460,6 +476,21 @@ def accumulate_event(
                 value=event,
             ),
         )
+        if not isinstance(cast(Any, event), BaseModel):
+            # construct_type_unchecked can silently return the raw dict when Pydantic's
+            # union discriminator fails (e.g. certain Pydantic versions or TypeAlias
+            # nesting); see https://github.com/anthropics/anthropic-sdk-python/issues/941.
+            # Fall back to parse_obj on the concrete class so nested fields (e.g. delta) are
+            # typed too; a failed parse is re-raised as TypeError with its cause chained.
+            raw = cast(Any, event)
+            if isinstance(raw, dict):
+                event_type = raw.get("type")
+                target_cls = _BETA_RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
+                if target_cls is not None:
+                    try:
+                        event = cast(BetaRawMessageStreamEvent, parse_obj(target_cls, raw))
+                    except Exception as err:
+                        raise TypeError(f"Could not parse raw {event_type!r} event as {target_cls.__name__}") from err
         if not isinstance(cast(Any, event), BaseModel):
             raise TypeError(
                 f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}"
diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py
index 5c0da9992..04962a4ac 100644
--- a/src/anthropic/lib/streaming/_messages.py
+++ b/src/anthropic/lib/streaming/_messages.py
@@ -21,14 +21,30 @@
     ParsedContentBlockStopEvent,
 )
 from ...types import RawMessageStreamEvent
+from ...types.raw_message_start_event import RawMessageStartEvent
+from ...types.raw_message_delta_event import RawMessageDeltaEvent
+from ...types.raw_message_stop_event import RawMessageStopEvent
+from ...types.raw_content_block_start_event import RawContentBlockStartEvent
+from ...types.raw_content_block_delta_event import RawContentBlockDeltaEvent
+from ...types.raw_content_block_stop_event import RawContentBlockStopEvent
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import consume_sync_iterator, consume_async_iterator
+from ..._compat import parse_obj
 from ..._models import build, construct_type, construct_type_unchecked
 from ..._streaming import Stream, AsyncStream
 from ..._utils._utils import is_given
 from .._parse._response import ResponseFormatT, parse_text
 from ...types.parsed_message import ParsedMessage, ParsedContentBlock
 
+_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
+    "message_start": RawMessageStartEvent,
+    "message_delta": RawMessageDeltaEvent,
+    "message_stop": RawMessageStopEvent,
+    "content_block_start": RawContentBlockStartEvent,
+    "content_block_delta": RawContentBlockDeltaEvent,
+    "content_block_stop": RawContentBlockStopEvent,
+}
+
 
 class MessageStream(Generic[ResponseFormatT]):
     text_stream: Iterator[str]
@@ -444,6 +460,21 @@ def accumulate_event(
                 value=event,
             ),
         )
+        if not isinstance(cast(Any, event), BaseModel):
+            # construct_type_unchecked can silently return the raw dict when Pydantic's
+            # union discriminator fails (e.g. certain Pydantic versions or TypeAlias
+            # nesting); see https://github.com/anthropics/anthropic-sdk-python/issues/941.
+            # Fall back to parse_obj on the concrete class so nested fields (e.g. delta) are
+            # typed too; a failed parse is re-raised as TypeError with its cause chained.
+            raw = cast(Any, event)
+            if isinstance(raw, dict):
+                event_type = raw.get("type")
+                target_cls = _RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
+                if target_cls is not None:
+                    try:
+                        event = cast(RawMessageStreamEvent, parse_obj(target_cls, raw))
+                    except Exception as err:
+                        raise TypeError(f"Could not parse raw {event_type!r} event as {target_cls.__name__}") from err
         if not isinstance(cast(Any, event), BaseModel):
             raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")
 
diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py
index b86a39063..ddf61ff85 100644
--- a/tests/lib/streaming/test_messages.py
+++ b/tests/lib/streaming/test_messages.py
@@ -349,3 +349,121 @@ def test_tracks_tool_input_type_alias_is_up_to_date() -> None:
                 f"ContentBlock type {block_type.__name__} has an input property, "
                 f"but is not included in TRACKS_TOOL_INPUT. You probably need to update the TRACKS_TOOL_INPUT type alias."
             )
+
+
+# Regression tests for https://github.com/anthropics/anthropic-sdk-python/issues/941
+#
+# When Pydantic's union discriminator fails silently (certain Pydantic versions or
+# TypeAlias nesting), construct_type_unchecked returns the raw dict. The fallback in
+# accumulate_event must produce fully-typed event objects — not just BaseModel shells
+# with nested dicts — so that downstream attribute access like event.delta.type works.
+
+
+def test_accumulate_event_raw_dict_text_delta_is_fully_typed() -> None:
+    """Raw dict content_block_delta events must be fully validated by the fallback so
+    that event.delta is a typed TextDelta object, not a raw dict. If it were a dict,
+    the event.delta.type access inside accumulate_event would raise AttributeError and
+    the text would never be appended to the snapshot."""
+    from anthropic.lib.streaming._messages import accumulate_event
+
+    snapshot = accumulate_event(
+        event={  # type: ignore[arg-type]
+            "type": "message_start",
+            "message": {
+                "id": "msg_test_941",
+                "type": "message",
+                "role": "assistant",
+                "content": [],
+                "model": "claude-3-opus-latest",
+                "stop_reason": None,
+                "stop_sequence": None,
+                "usage": {"input_tokens": 10, "output_tokens": 0},
+            },
+        },
+        current_snapshot=None,
+    )
+
+    snapshot = accumulate_event(
+        event={  # type: ignore[arg-type]
+            "type": "content_block_start",
+            "index": 0,
+            "content_block": {"type": "text", "text": ""},
+        },
+        current_snapshot=snapshot,
+    )
+
+    # This is the crux of issue #941: a content_block_delta arriving as a raw dict.
+    # The fallback must produce a RawContentBlockDeltaEvent whose .delta is a TextDelta
+    # (typed), not a dict — otherwise event.delta.type raises AttributeError and
+    # content.text is never updated.
+    snapshot = accumulate_event(
+        event={  # type: ignore[arg-type]
+            "type": "content_block_delta",
+            "index": 0,
+            "delta": {"type": "text_delta", "text": "hello"},
+        },
+        current_snapshot=snapshot,
+    )
+
+    # If event.delta were a raw dict the text would still be "" here.
+    assert snapshot.content[0].text == "hello"
+
+
+def test_accumulate_event_raw_dict_multiple_deltas_accumulate() -> None:
+    """Multiple raw dict text deltas must each append correctly, proving that the
+    fallback is re-applied on every event and not just the first one."""
+    from anthropic.lib.streaming._messages import accumulate_event
+
+    snapshot = accumulate_event(
+        event={  # type: ignore[arg-type]
+            "type": "message_start",
+            "message": {
+                "id": "msg_test_941b",
+                "type": "message",
+                "role": "assistant",
+                "content": [],
+                "model": "claude-3-opus-latest",
+                "stop_reason": None,
+                "stop_sequence": None,
+                "usage": {"input_tokens": 10, "output_tokens": 0},
+            },
+        },
+        current_snapshot=None,
+    )
+
+    snapshot = accumulate_event(
+        event={  # type: ignore[arg-type]
+            "type": "content_block_start",
+            "index": 0,
+            "content_block": {"type": "text", "text": ""},
+        },
+        current_snapshot=snapshot,
+    )
+
+    for word in ["Hello", " there", "!"]:
+        snapshot = accumulate_event(
+            event={  # type: ignore[arg-type]
+                "type": "content_block_delta",
+                "index": 0,
+                "delta": {"type": "text_delta", "text": word},
+            },
+            current_snapshot=snapshot,
+        )
+
+    assert snapshot.content[0].text == "Hello there!"
+
+
+def test_accumulate_event_raw_dict_unknown_type_still_raises() -> None:
+    """An unknown event type in the raw dict must still raise an exception so that
+    genuinely malformed events are not silently swallowed by the fallback.
+    The exact exception depends on how far deserialization gets: if the event survives
+    as a BaseModel with the wrong type, RuntimeError is raised by the event-ordering
+    check; if deserialization returns the raw dict, TypeError is raised."""
+    import pytest
+    from anthropic.lib.streaming._messages import accumulate_event
+
+    with pytest.raises((TypeError, RuntimeError)):
+        accumulate_event(
+            event={"type": "unknown_future_event", "data": "x"},  # type: ignore[arg-type]
+            current_snapshot=None,
+        )
diff --git a/tests/lib/streaming/test_partial_json.py b/tests/lib/streaming/test_partial_json.py
index 88a6e49f1..4e4241680 100644
--- a/tests/lib/streaming/test_partial_json.py
+++ b/tests/lib/streaming/test_partial_json.py
@@ -2,6 +2,7 @@
 from typing import List, cast
 
 import httpx
+import pytest
 
 from anthropic.types.beta import BetaDirectCaller, BetaToolUseBlock, BetaInputJSONDelta, BetaRawContentBlockDeltaEvent
 from anthropic.types.tool_use_block import ToolUseBlock
@@ -104,7 +105,6 @@ def test_trailing_strings_mode_header(self) -> None:
         assert "unfinished_field" in trailing_input
         assert trailing_input["unfinished_field"] == "incomplete value"
 
-    # test that with invalid JSON we throw the correct error
     def test_partial_json_with_invalid_json(self) -> None:
         """Test that invalid JSON raises an error."""
         message = ParsedBetaMessage(
@@ -147,3 +147,65 @@ def test_partial_json_with_invalid_json(self) -> None:
             )
         except Exception as e:
             raise AssertionError(f"Unexpected error type: {type(e).__name__} with message: {str(e)}") from e
+
+
+# Regression tests for https://github.com/anthropics/anthropic-sdk-python/issues/941 (beta path)
+#
+# When construct_type_unchecked silently returns a raw dict, the beta accumulate_event
+# fallback must produce a fully-validated BetaRawContentBlockDeltaEvent so that
+# event.delta.type (a typed BetaTextDelta) is accessible, not a raw dict.
+
+
+class TestBetaRawDictFallback:
+    def _make_snapshot(self) -> ParsedBetaMessage:  # type: ignore[type-arg]
+        return ParsedBetaMessage(
+            id="msg_test_beta_941",
+            type="message",
+            role="assistant",
+            content=[{"type": "text", "text": ""}],
+            model="claude-3-opus-latest",
+            stop_reason=None,
+            stop_sequence=None,
+            usage=BetaUsage(input_tokens=10, output_tokens=0),
+        )
+
+    def test_raw_dict_text_delta_is_fully_typed(self) -> None:
+        """A dict-shaped content_block_delta must be accepted by the beta
+        accumulate_event and its text must show up in the returned snapshot."""
+        delta_event = {
+            "type": "content_block_delta",
+            "index": 0,
+            "delta": {"type": "text_delta", "text": "hello"},
+        }
+        updated = accumulate_event(
+            event=delta_event,  # type: ignore[arg-type]
+            current_snapshot=self._make_snapshot(),
+            request_headers=httpx.Headers(),
+        )
+        assert updated.content[0].text == "hello"
+
+    def test_raw_dict_multiple_deltas_accumulate(self) -> None:
+        """Successive dict-shaped text deltas are appended in arrival order."""
+        message = self._make_snapshot()
+        for fragment in ("Hello", " beta", "!"):
+            message = accumulate_event(
+                event={  # type: ignore[arg-type]
+                    "type": "content_block_delta",
+                    "index": 0,
+                    "delta": {"type": "text_delta", "text": fragment},
+                },
+                current_snapshot=message,
+                request_headers=httpx.Headers(),
+            )
+        assert message.content[0].text == "Hello beta!"
+
+    def test_raw_dict_unknown_type_still_raises(self) -> None:
+        """An unrecognised event type sent with no snapshot must still raise: this
+        test accepts either TypeError or RuntimeError, since which one fires depends
+        on how far deserialization of the dict gets."""
+        with pytest.raises((TypeError, RuntimeError)):
+            accumulate_event(
+                event={"type": "unknown_future_event", "data": "x"},  # type: ignore[arg-type]
+                current_snapshot=None,
+                request_headers=httpx.Headers(),
+            )