Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..._types import NOT_GIVEN, NotGiven
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._compat import parse_obj
from ..._models import build, construct_type, construct_type_unchecked
from ._beta_types import (
BetaCitationEvent,
Expand All @@ -28,10 +29,25 @@
)
from ..._streaming import Stream, AsyncStream
from ...types.beta import BetaRawMessageStreamEvent
from ...types.beta.beta_raw_message_start_event import BetaRawMessageStartEvent
from ...types.beta.beta_raw_message_delta_event import BetaRawMessageDeltaEvent
from ...types.beta.beta_raw_message_stop_event import BetaRawMessageStopEvent
from ...types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent
from ...types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent
from ...types.beta.beta_raw_content_block_stop_event import BetaRawContentBlockStopEvent
from ..._utils._utils import is_given
from .._parse._response import ResponseFormatT, parse_text
from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock

_BETA_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": BetaRawMessageStartEvent,
"message_delta": BetaRawMessageDeltaEvent,
"message_stop": BetaRawMessageStopEvent,
"content_block_start": BetaRawContentBlockStartEvent,
"content_block_delta": BetaRawContentBlockDeltaEvent,
"content_block_stop": BetaRawContentBlockStopEvent,
}


class BetaMessageStream(Generic[ResponseFormatT]):
text_stream: Iterator[str]
Expand Down Expand Up @@ -460,6 +476,21 @@ def accumulate_event(
value=event,
),
)
if not isinstance(cast(Any, event), BaseModel):
# construct_type_unchecked can silently return the raw dict when Pydantic's
# union discriminator fails (e.g. certain Pydantic versions or TypeAlias
# nesting). Fall back to parse_obj on the concrete class so that nested
# fields (e.g. delta inside content_block_delta) are also validated and
# typed, not left as raw dicts. See https://github.com/anthropics/anthropic-sdk-python/issues/941
raw = cast(Any, event)
if isinstance(raw, dict):
event_type = raw.get("type")
target_cls = _BETA_RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
if target_cls is not None:
try:
event = cast(BetaRawMessageStreamEvent, parse_obj(target_cls, raw))
except Exception:
pass
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(
f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}"
Expand Down
31 changes: 31 additions & 0 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,30 @@
ParsedContentBlockStopEvent,
)
from ...types import RawMessageStreamEvent
from ...types.raw_message_start_event import RawMessageStartEvent
from ...types.raw_message_delta_event import RawMessageDeltaEvent
from ...types.raw_message_stop_event import RawMessageStopEvent
from ...types.raw_content_block_start_event import RawContentBlockStartEvent
from ...types.raw_content_block_delta_event import RawContentBlockDeltaEvent
from ...types.raw_content_block_stop_event import RawContentBlockStopEvent
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._compat import parse_obj
from ..._models import build, construct_type, construct_type_unchecked
from ..._streaming import Stream, AsyncStream
from ..._utils._utils import is_given
from .._parse._response import ResponseFormatT, parse_text
from ...types.parsed_message import ParsedMessage, ParsedContentBlock

_RAW_EVENT_TYPE_MAP: dict[str, type[BaseModel]] = {
"message_start": RawMessageStartEvent,
"message_delta": RawMessageDeltaEvent,
"message_stop": RawMessageStopEvent,
"content_block_start": RawContentBlockStartEvent,
"content_block_delta": RawContentBlockDeltaEvent,
"content_block_stop": RawContentBlockStopEvent,
}


class MessageStream(Generic[ResponseFormatT]):
text_stream: Iterator[str]
Expand Down Expand Up @@ -444,6 +460,21 @@ def accumulate_event(
value=event,
),
)
if not isinstance(cast(Any, event), BaseModel):
# construct_type_unchecked can silently return the raw dict when Pydantic's
# union discriminator fails (e.g. certain Pydantic versions or TypeAlias
# nesting). Fall back to parse_obj on the concrete class so that nested
# fields (e.g. delta inside content_block_delta) are also validated and
# typed, not left as raw dicts. See https://github.com/anthropics/anthropic-sdk-python/issues/941
raw = cast(Any, event)
if isinstance(raw, dict):
event_type = raw.get("type")
target_cls = _RAW_EVENT_TYPE_MAP.get(event_type) if isinstance(event_type, str) else None
if target_cls is not None:
try:
event = cast(RawMessageStreamEvent, parse_obj(target_cls, raw))
except Exception:
pass
if not isinstance(cast(Any, event), BaseModel):
raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

Expand Down
118 changes: 118 additions & 0 deletions tests/lib/streaming/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,121 @@ def test_tracks_tool_input_type_alias_is_up_to_date() -> None:
f"ContentBlock type {block_type.__name__} has an input property, "
f"but is not included in TRACKS_TOOL_INPUT. You probably need to update the TRACKS_TOOL_INPUT type alias."
)


# Regression tests for https://github.com/anthropics/anthropic-sdk-python/issues/941
#
# When Pydantic's union discriminator fails silently (certain Pydantic versions or
# TypeAlias nesting), construct_type_unchecked returns the raw dict. The fallback in
# accumulate_event must produce fully-typed event objects — not just BaseModel shells
# with nested dicts — so that downstream attribute access like event.delta.type works.


def test_accumulate_event_raw_dict_text_delta_is_fully_typed() -> None:
"""Raw dict content_block_delta events must be fully validated by the fallback so
that event.delta is a typed TextDelta object, not a raw dict. If it were a dict,
the event.delta.type access inside accumulate_event would raise AttributeError and
the text would never be appended to the snapshot."""
from anthropic.lib.streaming._messages import accumulate_event

snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "message_start",
"message": {
"id": "msg_test_941",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-opus-latest",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 10, "output_tokens": 0},
},
},
current_snapshot=None,
)

snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
current_snapshot=snapshot,
)

# This is the crux of issue #941: a content_block_delta arriving as a raw dict.
# The fallback must produce a RawContentBlockDeltaEvent whose .delta is a TextDelta
# (typed), not a dict — otherwise event.delta.type raises AttributeError and
# content.text is never updated.
snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "hello"},
},
current_snapshot=snapshot,
)

# If event.delta were a raw dict the text would still be "" here.
assert snapshot.content[0].text == "hello"


def test_accumulate_event_raw_dict_multiple_deltas_accumulate() -> None:
"""Multiple raw dict text deltas must each append correctly, proving that the
fallback is re-applied on every event and not just the first one."""
from anthropic.lib.streaming._messages import accumulate_event

snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "message_start",
"message": {
"id": "msg_test_941b",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-opus-latest",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 10, "output_tokens": 0},
},
},
current_snapshot=None,
)

snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
current_snapshot=snapshot,
)

for word in ["Hello", " there", "!"]:
snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": word},
},
current_snapshot=snapshot,
)

assert snapshot.content[0].text == "Hello there!"


def test_accumulate_event_raw_dict_unknown_type_still_raises() -> None:
"""An unknown event type in the raw dict must still raise an exception so that
genuinely malformed events are not silently swallowed by the fallback.
The exact exception depends on how far deserialization gets: if the event survives
as a BaseModel with the wrong type, RuntimeError is raised by the event-ordering
check; if deserialization returns the raw dict, TypeError is raised."""
import pytest
from anthropic.lib.streaming._messages import accumulate_event

with pytest.raises((TypeError, RuntimeError)):
accumulate_event(
event={"type": "unknown_future_event", "data": "x"}, # type: ignore[arg-type]
current_snapshot=None,
)
64 changes: 63 additions & 1 deletion tests/lib/streaming/test_partial_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, cast

import httpx
import pytest

from anthropic.types.beta import BetaDirectCaller, BetaToolUseBlock, BetaInputJSONDelta, BetaRawContentBlockDeltaEvent
from anthropic.types.tool_use_block import ToolUseBlock
Expand Down Expand Up @@ -104,7 +105,6 @@ def test_trailing_strings_mode_header(self) -> None:
assert "unfinished_field" in trailing_input
assert trailing_input["unfinished_field"] == "incomplete value"

# test that with invalid JSON we throw the correct error
def test_partial_json_with_invalid_json(self) -> None:
"""Test that invalid JSON raises an error."""
message = ParsedBetaMessage(
Expand Down Expand Up @@ -147,3 +147,65 @@ def test_partial_json_with_invalid_json(self) -> None:
)
except Exception as e:
raise AssertionError(f"Unexpected error type: {type(e).__name__} with message: {str(e)}") from e


# Regression tests for https://github.com/anthropics/anthropic-sdk-python/issues/941 (beta path)
#
# When construct_type_unchecked silently returns a raw dict, the beta accumulate_event
# fallback must produce a fully-validated BetaRawContentBlockDeltaEvent so that
# event.delta.type (a typed BetaTextDelta) is accessible, not a raw dict.


class TestBetaRawDictFallback:
def _make_snapshot(self) -> ParsedBetaMessage: # type: ignore[type-arg]
return ParsedBetaMessage(
id="msg_test_beta_941",
type="message",
role="assistant",
content=[{"type": "text", "text": ""}],
model="claude-3-opus-latest",
stop_reason=None,
stop_sequence=None,
usage=BetaUsage(input_tokens=10, output_tokens=0),
)

def test_raw_dict_text_delta_is_fully_typed(self) -> None:
"""Raw dict content_block_delta must be fully validated so that event.delta is a
typed BetaTextDelta; otherwise event.delta.type raises AttributeError and the
text is never appended to the snapshot."""
snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "hello"},
},
current_snapshot=self._make_snapshot(),
request_headers=httpx.Headers(),
)
assert snapshot.content[0].text == "hello"

def test_raw_dict_multiple_deltas_accumulate(self) -> None:
"""Multiple raw dict text deltas must each append correctly."""
snapshot = self._make_snapshot()
for word in ["Hello", " beta", "!"]:
snapshot = accumulate_event(
event={ # type: ignore[arg-type]
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": word},
},
current_snapshot=snapshot,
request_headers=httpx.Headers(),
)
assert snapshot.content[0].text == "Hello beta!"

def test_raw_dict_unknown_type_still_raises(self) -> None:
"""An unknown event type arriving before message_start must not be silently
swallowed by the fallback — it must raise either TypeError (if deserialization
returns a raw dict) or RuntimeError (if it produces a BaseModel with wrong type)."""
with pytest.raises((TypeError, RuntimeError)):
accumulate_event(
event={"type": "unknown_future_event", "data": "x"}, # type: ignore[arg-type]
current_snapshot=None,
request_headers=httpx.Headers(),
)