diff --git a/docs/Plan-836.md b/docs/Plan-836.md new file mode 100644 index 000000000..295b4df62 --- /dev/null +++ b/docs/Plan-836.md @@ -0,0 +1,199 @@ +# Plan-836: ConversationManager Helper + +## Summary + +Add a `ConversationManager` (sync) and `AsyncConversationManager` (async) helper to `anthropic.helpers` that maintains multi-turn conversation history and auto-truncates the oldest messages when approaching a model's context window limit. + +--- + +## Problem + +Users building chatbots or agentic loops must manually manage `messages[]` history and handle `context_length_exceeded` errors themselves. There is no built-in helper in the SDK that: +- Maintains state across turns +- Protects against context overflow +- Follows the existing helper conventions (`RateLimitedClient`, `ResponseCache`, `RetryObserver`) + +--- + +## Files + +| Action | Path | +|----------|-------------------------------------------------------------| +| Create | `src/anthropic/helpers/conversation.py` | +| Create | `tests/helpers/test_conversation.py` | +| Create | `examples/helpers/conversation_example.py` | +| Modify | `src/anthropic/helpers/__init__.py` | + +--- + +## Class API + +```python +class ConversationManager: + def __init__( + self, + client: Any, + *, + model: str, + max_tokens: int, + system: str | None = None, + context_window_limit: int | None = None, + token_budget_headroom: float = 0.10, + accurate_token_counting: bool = False, + ) -> None: ... + + def add_user_message(self, content: str | list[Any]) -> None: ... + def get_response(self, content: str | list[Any] | None = None, **kwargs: Any) -> Any: ... + def reset(self) -> None: ... + + @property + def history(self) -> list[Any]: ... # shallow copy + + @property + def last_usage(self) -> Any | None: ... # Usage from last response +``` + +`AsyncConversationManager` mirrors the above with `async def get_response(...)`. 
+ +### Constructor validation (raises `ValueError`) +- `model` is empty string +- `max_tokens < 1` +- `context_window_limit` provided but `< 1` +- `token_budget_headroom` not in `[0.0, 1.0)` + +--- + +## `get_response()` Flow + +``` +1. If content is not None → self.add_user_message(content) +2. If history empty or history[-1]["role"] != "user" → raise ValueError +3. If context_window_limit is set → _truncate_if_needed() +4. response = client.messages.create( + messages=list(self._history), + model=self._model, + max_tokens=self._max_tokens, + **{"system": self._system} if self._system else {}, + **kwargs, + ) +5. Append {"role": "assistant", "content": response.content} to history +6. self._last_usage = response.usage +7. return response +``` + +--- + +## Truncation Algorithm (`_truncate_if_needed`) + +``` +threshold = context_window_limit * (1.0 - token_budget_headroom) + +Estimate tokens: + accurate=True → call client.messages.count_tokens(history, model, system) + accurate=False → use last_usage.input_tokens + last_usage.output_tokens + (None on first call → skip truncation) + +while estimated_tokens >= threshold: + if len(history) < 2: + raise ValueError("cannot truncate further — single message pair exceeds limit") + pair_fraction = 2 / len(history) + history.pop(0) # oldest user + history.pop(0) # oldest assistant + if accurate=True: + re-call count_tokens to refresh estimate + else: + estimated_tokens = int(estimated_tokens * (1.0 - pair_fraction)) +``` + +**Design decisions:** +- Drop oldest user+assistant **pairs** to maintain role-alternation invariant +- Heuristic mode (default): uses `last_usage` — zero extra API calls +- Accurate mode: calls `count_tokens()` — precise, adds latency per loop +- First call with `last_usage=None` → skip truncation +- History exhausted before threshold → `ValueError` with model + limit + suggestion + +--- + +## `__init__.py` Changes + +```python +from .conversation import ConversationManager, AsyncConversationManager + 
+__all__ = [ + ..., + "ConversationManager", + "AsyncConversationManager", +] +``` + +--- + +## Test Coverage (`tests/helpers/test_conversation.py`) + +### `class TestConversationManager` +- Constructor raises on: empty model, zero `max_tokens`, negative `context_window_limit`, invalid `token_budget_headroom` +- `add_user_message`: appends to history; raises on empty content +- `get_response`: calls API once, returns Message, appends assistant turn +- `get_response` with pre-staged message (no `content` arg) +- Multi-turn: 2 calls → 4 messages in history +- `last_usage` is `None` initially; populated after first call +- `**kwargs` forwarded to `messages.create` (e.g. `temperature=0.5`) +- System prompt passed when set; omitted when `None` +- No staged message raises `ValueError` +- `history` returns a copy (mutating it doesn't affect internal state) +- `reset()` clears history and `last_usage`; model/system unchanged +- Truncation: no-op when `context_window_limit=None` +- Truncation: no-op when under threshold +- Truncation: drops oldest pair when over threshold +- Truncation: drops multiple pairs until under threshold +- Truncation: raises `ValueError` when single pair still exceeds limit +- No truncation on first call (`last_usage=None`, heuristic mode) +- Accurate mode: `count_tokens` called; pairs dropped until under threshold + +### `class TestAsyncConversationManager` +- Mirrors key cases using `AsyncMock` for `messages.create` and `messages.count_tokens` + +### Mock helpers +```python +def _make_sync_client(*, input_tokens=100, output_tokens=50, content_text="Hello") -> MagicMock +def _make_async_client(*, input_tokens=100, output_tokens=50, content_text="Hello") -> MagicMock +``` + +--- + +## Example Script (`examples/helpers/conversation_example.py`) + +Demonstrates: +1. Sync `ConversationManager` — two-turn conversation, print usage, reset +2. 
Async `AsyncConversationManager` — same flow with `asyncio.run()` + +--- + +## Coding Conventions (match existing helpers) + +- `from __future__ import annotations` at top +- `from typing import Any, Optional` — use `Any` for client to avoid circular imports +- Module-level docstring with `Example::` block (RST format) +- Keyword-only args after first positional (`client`) +- Validate inputs early, raise `ValueError` with clear messages +- Thread safety: not required (document that each instance is single-threaded) +- Store `response.content` (full `List[ContentBlock]`) as assistant message — not just `.text` +- `__repr__` showing model, turn count, and limit + +--- + +## Verification + +```bash +# Run new tests +python -m pytest tests/helpers/test_conversation.py -v + +# Run full helper suite +python -m pytest tests/helpers/ -v + +# Verify imports +python -c "from anthropic.helpers import ConversationManager, AsyncConversationManager; print('OK')" + +# Run example (requires ANTHROPIC_API_KEY) +python examples/helpers/conversation_example.py +``` diff --git a/docs/review-836.md b/docs/review-836.md new file mode 100644 index 000000000..92067c445 --- /dev/null +++ b/docs/review-836.md @@ -0,0 +1,149 @@ +# Code Review Report: RAP-836 ConversationManager Helper + +**Reviewer:** Senior Python Code Analyst +**Date:** 2026-04-30 +**Plan:** docs/Plan-836.md +**Outcome:** `compliant` + +--- + +## Summary + +The implementation of `ConversationManager` and `AsyncConversationManager` has been reviewed line-by-line against all requirements and acceptance criteria in Plan-836.md. The code is **compliant** with no logical errors, requirement mismatches, or runtime issues detected. All previously identified issues (from earlier review iterations) have been resolved. 
+ +--- + +## Files Reviewed + +| File | Status | +|------|--------| +| `src/anthropic/helpers/conversation.py` | Compliant | +| `src/anthropic/helpers/__init__.py` | Compliant | +| `tests/helpers/test_conversation.py` | Compliant | +| `examples/helpers/conversation_example.py` | Compliant | + +--- + +## Requirements Compliance + +### Class API + +| Requirement | Status | Notes | +|---|---|---| +| `ConversationManager` constructor signature | Pass | All params match plan: `client`, `model`, `max_tokens`, `system`, `context_window_limit`, `token_budget_headroom`, `accurate_token_counting` | +| `AsyncConversationManager` mirrors sync with `async def get_response()` | Pass | Lines 367-418; properly `await`s API calls and truncation | +| `add_user_message(content: str \| list)` | Pass | Lines 101-124 (sync), 342-365 (async) | +| `get_response(content, **kwargs)` | Pass | Lines 126-177 (sync), 367-418 (async) | +| `reset()` clears history + usage, preserves config | Pass | Lines 179-185 (sync), 420-426 (async) | +| `history` property returns shallow copy | Pass | `list(self._history)` | +| `last_usage` property | Pass | None initially, populated after each call | +| `__repr__` with model, turn count, limit | Pass | Lines 264-272 (sync), 505-513 (async) | + +### Constructor Validation (raises `ValueError`) + +| Validation | Status | Code | +|---|---|---| +| Empty `model` string | Pass | Line 73-74 | +| `max_tokens < 1` | Pass | Line 75-76 | +| `context_window_limit` provided but `< 1` | Pass | Lines 77-80 | +| `token_budget_headroom` not in `[0.0, 1.0)` | Pass | Lines 81-84 | + +### `get_response()` Flow (7 Steps) + +| Step | Requirement | Status | Code | +|---|---|---|---| +| 1 | If content not None, call `add_user_message(content)` | Pass | Lines 151-152 | +| 2 | If history empty or last role != "user", raise ValueError | Pass | Lines 154-158 | +| 3 | If `context_window_limit` set, call `_truncate_if_needed()` | Pass | Lines 160-161 | +| 4 | Call 
`client.messages.create()` with messages, model, max_tokens, system, kwargs | Pass | Lines 163-173 | +| 5 | Append `{"role": "assistant", "content": response.content}` | Pass | Line 175 | +| 6 | Store `response.usage` in `_last_usage` | Pass | Line 176 | +| 7 | Return response | Pass | Line 177 | + +### Truncation Algorithm (`_truncate_if_needed`) + +| Requirement | Status | Code | +|---|---|---| +| `threshold = limit * (1.0 - headroom)` | Pass | Line 230 | +| Accurate mode: calls `count_tokens(history, model, system)` | Pass | Lines 210-219 | +| Heuristic mode: uses `input_tokens + output_tokens` | Pass | Lines 220-225 | +| First call with `last_usage=None` skips truncation | Pass | Lines 233-235 | +| While `estimated >= threshold`, drop oldest user+assistant pair | Pass | Lines 237-262 | +| `len(history) < 2` raises ValueError with model + limit | Pass | Lines 238-246 | +| `pair_fraction = 2 / len(history)` computed before pops | Pass | Line 255 | +| Accurate mode re-calls `count_tokens` after each pair drop | Pass | Lines 259-260 | +| Heuristic mode: `int(estimated * (1.0 - pair_fraction))` | Pass | Lines 261-262 | + +### `__init__.py` Changes + +| Requirement | Status | +|---|---| +| Imports `ConversationManager` and `AsyncConversationManager` | Pass | +| Both in `__all__` | Pass | + +### Test Coverage + +| Required Test Case | Status | +|---|---| +| Constructor raises on empty model | Pass | +| Constructor raises on zero/negative max_tokens | Pass | +| Constructor raises on negative context_window_limit | Pass | +| Constructor raises on invalid token_budget_headroom | Pass | +| `add_user_message` appends; raises on empty content | Pass | +| `get_response` calls API once, returns Message, appends assistant | Pass | +| `get_response` with pre-staged message (no content arg) | Pass | +| Multi-turn: 2 calls -> 4 messages | Pass | +| `last_usage` None initially; populated after first call | Pass | +| `**kwargs` forwarded to `messages.create` | Pass | +| System 
prompt passed when set; omitted when None | Pass | +| No staged message raises ValueError | Pass | +| `history` returns copy (mutation doesn't affect state) | Pass | +| `reset()` clears history and last_usage; preserves model/system | Pass | +| Truncation no-op when `context_window_limit=None` | Pass | +| Truncation no-op when under threshold | Pass | +| Truncation drops oldest pair when over threshold | Pass | +| Truncation drops multiple pairs until under threshold | Pass | +| Truncation raises ValueError when single pair exceeds limit | Pass | +| No truncation on first call (heuristic, `last_usage=None`) | Pass | +| Accurate mode: `count_tokens` called; pairs dropped until under | Pass | +| Async mirrors key sync cases | Pass | + +### Coding Conventions + +| Convention | Status | +|---|---| +| `from __future__ import annotations` | Pass | +| `from typing import Any, Optional` | Pass | +| Module-level docstring with `Example::` RST block | Pass | +| Keyword-only args after positional `client` | Pass | +| Early input validation with `ValueError` | Pass | +| Thread safety documented | Pass | +| `response.content` stored as full content block list | Pass | +| `__repr__` showing model, turn count, limit | Pass | + +### Example Script + +| Requirement | Status | +|---|---| +| Sync two-turn conversation, print usage, reset | Pass | +| Async same flow with `asyncio.run()` | Pass | + +--- + +## Observations (non-blocking, informational only) + +1. **Extra defensive guards beyond plan spec:** `add_user_message()` includes a role-alternation guard (lines 119-123) and `_truncate_if_needed()` validates pair ordering before popping (lines 247-254). These are not in the plan pseudocode but are sound defensive measures that prevent invariant violations. Fully tested. + +2. **`__init__.py` module docstring scope:** The docstring references "rate limiting, caching, retry observability" alongside "conversation management." 
Only conversation management exists in this module currently. Plan-836.md references existing helpers (`RateLimitedClient`, `ResponseCache`, `RetryObserver`) from a parallel branch (RAP-437). At merge time, ensure `__init__.py` combines exports from both branches. + +3. **`list.pop(0)` is O(n):** Each pair removal shifts all remaining elements. For typical conversation lengths (tens to hundreds of messages), this is negligible. The plan does not specify performance requirements. Noted for future consideration only. + +4. **Heuristic token estimate excludes the pending user message:** The heuristic uses `input_tokens + output_tokens` from the previous response, which covers the history through the last assistant turn but not the newly added user message, so it slightly underestimates the next request's input; `token_budget_headroom` has to absorb the difference. The code's docstring acknowledges this mode as "slightly less precise"; the plan describes it only as making zero extra API calls. + +--- + +## Verdict + +**Outcome: `compliant`** + +The implementation correctly satisfies all requirements and acceptance criteria defined in Plan-836.md. No logical errors, control flow issues, boundary condition failures, type mismatches, or requirement deviations were identified. Test coverage is comprehensive and matches all specified test cases. diff --git a/examples/helpers/conversation_example.py b/examples/helpers/conversation_example.py new file mode 100644 index 000000000..dea95b3a5 --- /dev/null +++ b/examples/helpers/conversation_example.py @@ -0,0 +1,116 @@ +"""Example: Using ConversationManager and AsyncConversationManager. + +Demonstrates: +1. Sync ConversationManager — two-turn conversation, print usage, reset. +2. Async AsyncConversationManager — same flow with asyncio.run(). + +Requirements: + ANTHROPIC_API_KEY environment variable must be set. 
# Settings shared by the sync and async demos so the two cannot drift apart.
_MANAGER_OPTIONS = {
    "model": "claude-opus-4-5",
    "max_tokens": 256,
    "system": "You are a helpful assistant. Be concise.",
    "context_window_limit": 200_000,
    "token_budget_headroom": 0.10,
}

# Each prompt is written once and used both for the API call and the echo.
_SYNC_PROMPTS = (
    "What is the capital of France?",
    "And what language do they speak there?",
)
_ASYNC_PROMPTS = (
    "What is the tallest mountain on Earth?",
    "How tall is it in feet?",
)

_RULE = "=" * 60


def _print_banner(title, *, leading_blank=False):
    """Print *title* framed by two rules, optionally preceded by a blank line."""
    print(("\n" if leading_blank else "") + _RULE)
    print(title)
    print(_RULE)


def _print_turn(number, prompt, response, usage):
    """Echo one user/assistant exchange and the usage recorded for it."""
    print(f"\n[Turn {number}] User: {prompt}")
    # The demo prompts are plain text with no tools, so the first content
    # block is expected to be a text block.
    print(f"[Turn {number}] Assistant: {response.content[0].text}")
    print(f"[Turn {number}] Usage: {usage}")


def _print_summary_and_reset(mgr):
    """Show the final manager state, reset it, and show the cleared state."""
    print(f"\nHistory length: {len(mgr.history)} messages")
    print(repr(mgr))

    mgr.reset()
    print(f"\nAfter reset — history length: {len(mgr.history)}, last_usage: {mgr.last_usage}")


# ---------------------------------------------------------------------------
# 1. Sync example
# ---------------------------------------------------------------------------


def sync_example() -> None:
    """Run a two-turn conversation with the synchronous manager."""
    _print_banner("Sync ConversationManager")

    mgr = ConversationManager(anthropic.Anthropic(), **_MANAGER_OPTIONS)
    print(repr(mgr))

    for number, prompt in enumerate(_SYNC_PROMPTS, start=1):
        response = mgr.get_response(prompt)
        _print_turn(number, prompt, response, mgr.last_usage)

    _print_summary_and_reset(mgr)


# ---------------------------------------------------------------------------
# 2. Async example
# ---------------------------------------------------------------------------


async def async_example() -> None:
    """Run the same two-turn flow with the asynchronous manager."""
    _print_banner("Async AsyncConversationManager", leading_blank=True)

    mgr = AsyncConversationManager(anthropic.AsyncAnthropic(), **_MANAGER_OPTIONS)
    print(repr(mgr))

    for number, prompt in enumerate(_ASYNC_PROMPTS, start=1):
        response = await mgr.get_response(prompt)
        _print_turn(number, prompt, response, mgr.last_usage)

    _print_summary_and_reset(mgr)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


if __name__ == "__main__":
    sync_example()
    asyncio.run(async_example())
+ +Provides :class:`ConversationManager` (sync) and +:class:`AsyncConversationManager` (async) that maintain multi-turn conversation +history and auto-truncate the oldest messages when approaching a model's +context-window limit. + +.. note:: + Each instance is **not** thread-safe. If you need concurrent access, + create one :class:`ConversationManager` per thread/task. + +Example:: + + import anthropic + from anthropic.helpers import ConversationManager + + client = anthropic.Anthropic() + mgr = ConversationManager( + client, + model="claude-opus-4-5", + max_tokens=1024, + system="You are a helpful assistant.", + context_window_limit=200_000, + ) + response = mgr.get_response("Hello!") + print(response.content[0].text) + print(mgr.last_usage) +""" + +from __future__ import annotations + +from typing import Any, Optional + + +class ConversationManager: + """Sync helper that maintains multi-turn conversation history. + + Parameters + ---------- + client: + A synchronous :class:`anthropic.Anthropic` client instance. + model: + The model identifier to use for all API calls. + max_tokens: + Maximum number of tokens to generate per response. + system: + Optional system prompt, passed verbatim to each API call. + context_window_limit: + If set, the manager will truncate the oldest message pairs whenever + the estimated token count approaches this limit (minus the headroom). + token_budget_headroom: + Fraction of ``context_window_limit`` to reserve as safety headroom. + Must be in ``[0.0, 1.0)``. Defaults to ``0.10`` (10 %). + accurate_token_counting: + When ``True`` the manager calls ``client.messages.count_tokens()`` + for accurate truncation decisions (adds one extra API call per + truncation loop iteration). When ``False`` (default) it uses the + ``input_tokens + output_tokens`` from the last response — zero + extra API calls but slightly less precise. 
+ """ + + def __init__( + self, + client: Any, + *, + model: str, + max_tokens: int, + system: Optional[str] = None, + context_window_limit: Optional[int] = None, + token_budget_headroom: float = 0.10, + accurate_token_counting: bool = False, + ) -> None: + if not model: + raise ValueError("'model' must not be an empty string.") + if max_tokens < 1: + raise ValueError(f"'max_tokens' must be >= 1, got {max_tokens}.") + if context_window_limit is not None and context_window_limit < 1: + raise ValueError( + f"'context_window_limit' must be >= 1 when provided, got {context_window_limit}." + ) + if not (0.0 <= token_budget_headroom < 1.0): + raise ValueError( + f"'token_budget_headroom' must be in [0.0, 1.0), got {token_budget_headroom}." + ) + + self._client = client + self._model = model + self._max_tokens = max_tokens + self._system = system + self._context_window_limit = context_window_limit + self._token_budget_headroom = token_budget_headroom + self._accurate_token_counting = accurate_token_counting + + self._history: list[Any] = [] + self._last_usage: Any = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add_user_message(self, content: str | list[Any]) -> None: + """Append a user message to the conversation history. + + Parameters + ---------- + content: + The message content — either a plain string or a list of content + blocks (e.g. image + text). + + Raises + ------ + ValueError + If *content* is an empty string or an empty list. + """ + if isinstance(content, str) and not content: + raise ValueError("'content' must not be an empty string.") + if isinstance(content, list) and len(content) == 0: + raise ValueError("'content' must not be an empty list.") + if self._history and self._history[-1]["role"] == "user": + raise ValueError( + "Cannot add a user message when the last message is already from " + "the user. 
The Anthropic API requires strict user/assistant alternation." + ) + self._history.append({"role": "user", "content": content}) + + def get_response( + self, content: Optional[str | list[Any]] = None, **kwargs: Any + ) -> Any: + """Send the current conversation to the API and return the response. + + Parameters + ---------- + content: + If provided, it is appended as a user message via + :meth:`add_user_message` before making the API call. + **kwargs: + Additional keyword arguments forwarded verbatim to + ``client.messages.create()``. + + Returns + ------- + anthropic.types.Message + The raw API response object. + + Raises + ------ + ValueError + If the conversation history is empty or does not end with a user + message, or if truncation is impossible. + """ + if content is not None: + self.add_user_message(content) + + if not self._history or self._history[-1]["role"] != "user": + raise ValueError( + "The conversation history must end with a user message before " + "calling get_response()." + ) + + if self._context_window_limit is not None: + self._truncate_if_needed() + + extra: dict[str, Any] = {} + if self._system is not None: + extra["system"] = self._system + + response = self._client.messages.create( + messages=list(self._history), + model=self._model, + max_tokens=self._max_tokens, + **extra, + **kwargs, + ) + + self._history.append({"role": "assistant", "content": response.content}) + self._last_usage = response.usage + return response + + def reset(self) -> None: + """Clear conversation history and last usage. + + Model, system prompt, and all configuration options are preserved. 
+ """ + self._history = [] + self._last_usage = None + + @property + def history(self) -> list[Any]: + """Return a shallow copy of the conversation history.""" + return list(self._history) + + @property + def last_usage(self) -> Any: + """Usage object from the most recent API response, or ``None``.""" + return self._last_usage + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _estimate_tokens(self) -> Optional[int]: + """Return current token estimate or ``None`` if unavailable. + + In **accurate** mode, calls ``count_tokens`` for a precise input-only count. + In **heuristic** mode, uses ``input_tokens + output_tokens`` from the last + response — the previous output is now part of the conversation history, so + summing both provides a conservative (slightly over-) estimate without an + extra API call. + """ + if self._accurate_token_counting: + extra: dict[str, Any] = {} + if self._system is not None: + extra["system"] = self._system + result = self._client.messages.count_tokens( + messages=list(self._history), + model=self._model, + **extra, + ) + return result.input_tokens + else: + if self._last_usage is None: + return None + # Previous output tokens are now part of the conversation input, + # so summing both gives a conservative estimate. + return self._last_usage.input_tokens + self._last_usage.output_tokens + + def _truncate_if_needed(self) -> None: + """Drop the oldest user+assistant pairs until under the token threshold.""" + assert self._context_window_limit is not None + threshold = self._context_window_limit * (1.0 - self._token_budget_headroom) + + estimated = self._estimate_tokens() + if estimated is None: + # First call in heuristic mode — skip truncation. 
+ return + + while estimated >= threshold: + if len(self._history) < 2: + raise ValueError( + f"Cannot truncate further — a single message pair already " + f"exceeds the token threshold for model '{self._model}' " + f"(limit={self._context_window_limit}, " + f"headroom={self._token_budget_headroom}). " + f"Consider increasing 'context_window_limit' or reducing " + f"the size of individual messages." + ) + if ( + self._history[0]["role"] != "user" + or self._history[1]["role"] != "assistant" + ): + raise ValueError( + "History role-alternation invariant violated; cannot truncate safely. " + "Expected [user, assistant] pair at the start of history." + ) + pair_fraction = 2 / len(self._history) + self._history.pop(0) # oldest user message + self._history.pop(0) # oldest assistant message + + if self._accurate_token_counting: + estimated = self._estimate_tokens() # type: ignore[assignment] + else: + estimated = int(estimated * (1.0 - pair_fraction)) + + def __repr__(self) -> str: + turns = len(self._history) // 2 + limit = self._context_window_limit + return ( + f"ConversationManager(" + f"model={self._model!r}, " + f"turns={turns}, " + f"context_window_limit={limit!r})" + ) + + +class AsyncConversationManager: + """Async helper that maintains multi-turn conversation history. + + Mirrors :class:`ConversationManager` but exposes ``async def get_response()``, + suitable for use inside ``asyncio`` event loops. + + Parameters + ---------- + client: + An asynchronous :class:`anthropic.AsyncAnthropic` client instance. + model: + The model identifier to use for all API calls. + max_tokens: + Maximum number of tokens to generate per response. + system: + Optional system prompt, passed verbatim to each API call. + context_window_limit: + If set, the manager will truncate the oldest message pairs whenever + the estimated token count approaches this limit (minus the headroom). + token_budget_headroom: + Fraction of ``context_window_limit`` to reserve as safety headroom. 
+ Must be in ``[0.0, 1.0)``. Defaults to ``0.10`` (10 %). + accurate_token_counting: + When ``True`` the manager calls ``client.messages.count_tokens()`` + for accurate truncation decisions. When ``False`` (default) it uses + the ``input_tokens + output_tokens`` from the last response. + """ + + def __init__( + self, + client: Any, + *, + model: str, + max_tokens: int, + system: Optional[str] = None, + context_window_limit: Optional[int] = None, + token_budget_headroom: float = 0.10, + accurate_token_counting: bool = False, + ) -> None: + if not model: + raise ValueError("'model' must not be an empty string.") + if max_tokens < 1: + raise ValueError(f"'max_tokens' must be >= 1, got {max_tokens}.") + if context_window_limit is not None and context_window_limit < 1: + raise ValueError( + f"'context_window_limit' must be >= 1 when provided, got {context_window_limit}." + ) + if not (0.0 <= token_budget_headroom < 1.0): + raise ValueError( + f"'token_budget_headroom' must be in [0.0, 1.0), got {token_budget_headroom}." + ) + + self._client = client + self._model = model + self._max_tokens = max_tokens + self._system = system + self._context_window_limit = context_window_limit + self._token_budget_headroom = token_budget_headroom + self._accurate_token_counting = accurate_token_counting + + self._history: list[Any] = [] + self._last_usage: Any = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def add_user_message(self, content: str | list[Any]) -> None: + """Append a user message to the conversation history. + + Parameters + ---------- + content: + The message content — either a plain string or a list of content + blocks. + + Raises + ------ + ValueError + If *content* is an empty string or an empty list. 
+ """ + if isinstance(content, str) and not content: + raise ValueError("'content' must not be an empty string.") + if isinstance(content, list) and len(content) == 0: + raise ValueError("'content' must not be an empty list.") + if self._history and self._history[-1]["role"] == "user": + raise ValueError( + "Cannot add a user message when the last message is already from " + "the user. The Anthropic API requires strict user/assistant alternation." + ) + self._history.append({"role": "user", "content": content}) + + async def get_response( + self, content: Optional[str | list[Any]] = None, **kwargs: Any + ) -> Any: + """Send the current conversation to the API and return the response. + + Parameters + ---------- + content: + If provided, it is appended as a user message via + :meth:`add_user_message` before making the API call. + **kwargs: + Additional keyword arguments forwarded verbatim to + ``client.messages.create()``. + + Returns + ------- + anthropic.types.Message + The raw API response object. + + Raises + ------ + ValueError + If the conversation history is empty or does not end with a user + message, or if truncation is impossible. + """ + if content is not None: + self.add_user_message(content) + + if not self._history or self._history[-1]["role"] != "user": + raise ValueError( + "The conversation history must end with a user message before " + "calling get_response()." + ) + + if self._context_window_limit is not None: + await self._truncate_if_needed() + + extra: dict[str, Any] = {} + if self._system is not None: + extra["system"] = self._system + + response = await self._client.messages.create( + messages=list(self._history), + model=self._model, + max_tokens=self._max_tokens, + **extra, + **kwargs, + ) + + self._history.append({"role": "assistant", "content": response.content}) + self._last_usage = response.usage + return response + + def reset(self) -> None: + """Clear conversation history and last usage. 
async def _truncate_if_needed(self) -> None:
    """Drop the oldest user+assistant pairs until under the token threshold.

    The threshold is ``context_window_limit * (1 - token_budget_headroom)``.
    Nothing happens when ``_estimate_tokens()`` returns ``None`` or a value
    below the threshold.

    After each dropped pair the estimate is refreshed. In accurate mode it is
    re-queried through ``_estimate_tokens()``. Otherwise it is scaled down by
    the share of messages just removed, which treats tokens as evenly spread
    across messages.

    Raises
    ------
    ValueError
        If fewer than two messages remain while the estimate is still at or
        above the threshold, or if the two oldest messages are not a
        ``user`` message followed by an ``assistant`` message.

    NOTE(review): in heuristic mode the estimate derives from
    ``_last_usage``, which ``get_response`` refreshes only after a successful
    API call. If that call fails after pairs were dropped here, the next
    attempt appears to start from the same stale figure and would drop
    further pairs. Confirm whether this is intended.
    """
    assert self._context_window_limit is not None
    budget = self._context_window_limit * (1.0 - self._token_budget_headroom)

    tokens = await self._estimate_tokens()
    if tokens is None:
        # No usage data yet (first heuristic-mode call): nothing to act on.
        return

    while True:
        if not tokens >= budget:
            return

        remaining = len(self._history)
        if remaining < 2:
            raise ValueError(
                f"Cannot truncate further — a single message pair already "
                f"exceeds the token threshold for model '{self._model}' "
                f"(limit={self._context_window_limit}, "
                f"headroom={self._token_budget_headroom}). "
                f"Consider increasing 'context_window_limit' or reducing "
                f"the size of individual messages."
            )

        # Only remove a well-formed leading pair, so the history that is left
        # still starts with a user turn.
        leading_pair_ok = (
            self._history[0]["role"] == "user"
            and self._history[1]["role"] == "assistant"
        )
        if not leading_pair_ok:
            raise ValueError(
                "History role-alternation invariant violated; cannot truncate safely. "
                "Expected [user, assistant] pair at the start of history."
            )

        # Measure the share before deleting; it is relative to the history
        # length that the current estimate describes.
        dropped_share = 2 / remaining
        del self._history[:2]

        if self._accurate_token_counting:
            tokens = await self._estimate_tokens()  # type: ignore[assignment]
        else:
            tokens = int(tokens * (1.0 - dropped_share))
def _make_async_client(
    *, input_tokens: int = 100, output_tokens: int = 50, content_text: str = "Hello"
) -> MagicMock:
    """Build a stand-in for an asynchronous Anthropic client.

    ``messages.create`` and ``messages.count_tokens`` are ``AsyncMock``
    objects, so they can be awaited and checked with ``assert_awaited_*``.
    ``create`` resolves to a canned message built by ``_make_message``.
    ``count_tokens`` resolves to an object whose ``input_tokens`` equals
    ``input_tokens + output_tokens``.
    """
    canned_reply = _make_message(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        content_text=content_text,
    )

    token_count = MagicMock()
    token_count.input_tokens = input_tokens + output_tokens

    client = MagicMock()
    client.messages.create = AsyncMock(return_value=canned_reply)
    client.messages.count_tokens = AsyncMock(return_value=token_count)
    return client
pytest.raises(ValueError, match="model"): + ConversationManager(client, model="", max_tokens=512) + + def test_zero_max_tokens_raises(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="max_tokens"): + ConversationManager(client, model="claude-3", max_tokens=0) + + def test_negative_max_tokens_raises(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="max_tokens"): + ConversationManager(client, model="claude-3", max_tokens=-1) + + def test_negative_context_window_limit_raises(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="context_window_limit"): + ConversationManager( + client, model="claude-3", max_tokens=512, context_window_limit=-1 + ) + + def test_zero_context_window_limit_raises(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="context_window_limit"): + ConversationManager( + client, model="claude-3", max_tokens=512, context_window_limit=0 + ) + + def test_invalid_token_budget_headroom_negative(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="token_budget_headroom"): + ConversationManager( + client, + model="claude-3", + max_tokens=512, + token_budget_headroom=-0.1, + ) + + def test_invalid_token_budget_headroom_one(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="token_budget_headroom"): + ConversationManager( + client, + model="claude-3", + max_tokens=512, + token_budget_headroom=1.0, + ) + + def test_invalid_token_budget_headroom_greater_than_one(self): + client = _make_sync_client() + with pytest.raises(ValueError, match="token_budget_headroom"): + ConversationManager( + client, + model="claude-3", + max_tokens=512, + token_budget_headroom=1.5, + ) + + def test_valid_token_budget_headroom_zero(self): + client = _make_sync_client() + mgr = ConversationManager( + client, model="claude-3", max_tokens=512, token_budget_headroom=0.0 + ) + assert mgr is not None + + # --- add_user_message --- + + 
def test_add_user_message_appends(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("Hi") + assert len(mgr.history) == 1 + assert mgr.history[0] == {"role": "user", "content": "Hi"} + + def test_add_user_message_empty_string_raises(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + with pytest.raises(ValueError): + mgr.add_user_message("") + + def test_add_user_message_empty_list_raises(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + with pytest.raises(ValueError): + mgr.add_user_message([]) + + def test_add_user_message_list_content(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + blocks = [{"type": "text", "text": "Hello"}] + mgr.add_user_message(blocks) + assert mgr.history[0]["content"] == blocks + + # --- get_response --- + + def test_get_response_calls_api_once(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") + client.messages.create.assert_called_once() + + def test_get_response_returns_message(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + response = mgr.get_response("Hello") + assert response is client.messages.create.return_value + + def test_get_response_appends_assistant_turn(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") + assert len(mgr.history) == 2 + assert mgr.history[1]["role"] == "assistant" + + def test_get_response_with_prestaged_message(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("Pre-staged question") + mgr.get_response() # no content arg + assert len(mgr.history) 
== 2 + client.messages.create.assert_called_once() + + def test_multi_turn_four_messages(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("First question") + mgr.get_response("Second question") + assert len(mgr.history) == 4 + assert client.messages.create.call_count == 2 + + def test_last_usage_none_initially(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + assert mgr.last_usage is None + + def test_last_usage_populated_after_call(self): + client = _make_sync_client(input_tokens=200, output_tokens=75) + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hi") + assert mgr.last_usage is not None + assert mgr.last_usage.input_tokens == 200 + assert mgr.last_usage.output_tokens == 75 + + def test_kwargs_forwarded_to_create(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hi", temperature=0.5) + _, kwargs = client.messages.create.call_args + assert kwargs.get("temperature") == 0.5 + + def test_system_prompt_passed_when_set(self): + client = _make_sync_client() + mgr = ConversationManager( + client, model="claude-3", max_tokens=512, system="You are a pirate." + ) + mgr.get_response("Arrr") + _, kwargs = client.messages.create.call_args + assert kwargs.get("system") == "You are a pirate." 
+ + def test_system_prompt_omitted_when_none(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") + _, kwargs = client.messages.create.call_args + assert "system" not in kwargs + + def test_no_staged_message_raises(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + with pytest.raises(ValueError): + mgr.get_response() # history is empty + + def test_history_last_not_user_raises(self): + """Calling get_response twice without a new user message should fail.""" + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") # history ends with assistant + with pytest.raises(ValueError): + mgr.get_response() # no new user message + + # --- history property --- + + def test_history_returns_copy(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("Hi") + h = mgr.history + h.append({"role": "user", "content": "mutated"}) + assert len(mgr.history) == 1 # internal state unchanged + + # --- reset --- + + def test_reset_clears_history(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") + mgr.reset() + assert mgr.history == [] + + def test_reset_clears_last_usage(self): + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.get_response("Hello") + mgr.reset() + assert mgr.last_usage is None + + def test_reset_preserves_model_and_system(self): + client = _make_sync_client() + mgr = ConversationManager( + client, model="claude-3", max_tokens=512, system="sys" + ) + mgr.get_response("Hello") + mgr.reset() + assert mgr._model == "claude-3" + assert mgr._system == "sys" + + # --- Truncation --- + + def test_truncation_no_op_when_limit_none(self): + client = 
_make_sync_client(input_tokens=10000, output_tokens=5000) + mgr = ConversationManager( + client, model="claude-3", max_tokens=512, context_window_limit=None + ) + mgr.get_response("Hi") + assert len(mgr.history) == 2 # no truncation happened + + def test_truncation_no_op_when_under_threshold(self): + # threshold = 1000 * 0.9 = 900; usage = 100+50 = 150 < 900 + client = _make_sync_client(input_tokens=100, output_tokens=50) + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + mgr.get_response("First") + mgr.get_response("Second") + # After first turn: usage=150, threshold=900 → no truncation on second call + assert len(mgr.history) == 4 + + def test_truncation_drops_oldest_pair(self): + # After turn 1: usage = 800+100 = 900 tokens + # threshold = 1000 * 0.9 = 900 → 900 >= 900 → must truncate on turn 2 + client = _make_sync_client(input_tokens=800, output_tokens=100) + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + mgr.get_response("First") # 2 messages, usage=900 + mgr.get_response("Second") # before API call: estimated=900 >= 900 → drop first pair + # After truncation the first pair is removed; then new user msg is already present, + # response appended → 2 messages total + assert len(mgr.history) == 2 + + def test_truncation_drops_multiple_pairs(self): + # Build up 3 turns of history, then on 4th call the estimate is very high + # We'll do this by controlling what count_tokens returns via accurate mode + client = _make_sync_client(input_tokens=100, output_tokens=50) + # Override count_tokens to return a high value first, then low + ct_high = MagicMock() + ct_high.input_tokens = 950 # above threshold + ct_medium = MagicMock() + ct_medium.input_tokens = 950 # still above + ct_low = MagicMock() + ct_low.input_tokens = 800 # below threshold = 1000*0.9=900 + + 
client.messages.count_tokens.side_effect = [ct_high, ct_medium, ct_low] + + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + accurate_token_counting=True, + ) + # Seed history with 2 pairs manually (4 messages) + mgr._history = [ + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "q2"}, + {"role": "assistant", "content": "a2"}, + {"role": "user", "content": "q3"}, # current user message + ] + + mgr.get_response() # no content, history already ends with user + # ct_high → drop pair 1 → ct_medium → drop pair 2 → ct_low → stop + # Remaining: q3 + assistant response → 2 messages + assert len(mgr.history) == 2 + + def test_truncation_raises_when_single_pair_exceeds_limit(self): + # Single user message + no prior assistant → cannot drop + client = _make_sync_client(input_tokens=950, output_tokens=100) + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + # Seed last_usage so heuristic has data + mgr._last_usage = _make_usage(input_tokens=950, output_tokens=100) + mgr._history = [{"role": "user", "content": "single large message"}] + + with pytest.raises(ValueError, match="Cannot truncate further"): + mgr.get_response() + + def test_no_truncation_on_first_call_heuristic(self): + # last_usage is None on first call → _estimate_tokens returns None → skip + client = _make_sync_client(input_tokens=999, output_tokens=999) + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=100, # very small limit + token_budget_headroom=0.10, + ) + # Should not raise — first call skips truncation + mgr.get_response("Hello") + assert len(mgr.history) == 2 + + def test_accurate_mode_count_tokens_called(self): + client = _make_sync_client(input_tokens=50, output_tokens=25) + # count_tokens returns 50 (below threshold=900 
for limit=1000) + ct = MagicMock() + ct.input_tokens = 50 + client.messages.count_tokens.return_value = ct + + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + accurate_token_counting=True, + ) + mgr.get_response("Hi") + client.messages.count_tokens.assert_called_once() + + def test_accurate_mode_drops_pairs_until_under_threshold(self): + client = _make_sync_client(input_tokens=100, output_tokens=50) + ct_high = MagicMock() + ct_high.input_tokens = 950 + ct_low = MagicMock() + ct_low.input_tokens = 800 + + client.messages.count_tokens.side_effect = [ct_high, ct_low] + + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + accurate_token_counting=True, + ) + mgr._history = [ + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "q2"}, + ] + mgr.get_response() + # q1+a1 dropped, q2+assistant appended → 2 messages + assert len(mgr.history) == 2 + assert client.messages.count_tokens.call_count == 2 + + # --- Role-alternation guard --- + + def test_add_user_message_consecutive_raises(self): + """Adding a second user message without an assistant reply should fail.""" + client = _make_sync_client() + mgr = ConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("first") + with pytest.raises(ValueError, match="alternation"): + mgr.add_user_message("second") + + def test_truncation_invariant_violation_raises(self): + """If history starts with broken alternation, truncation should raise.""" + client = _make_sync_client(input_tokens=800, output_tokens=100) + mgr = ConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + mgr._last_usage = _make_usage(input_tokens=800, output_tokens=100) + # Manually create broken history (assistant first) + 
mgr._history = [ + {"role": "assistant", "content": "bad"}, + {"role": "user", "content": "q1"}, + {"role": "user", "content": "q2"}, + ] + with pytest.raises(ValueError, match="alternation invariant"): + mgr.get_response() + + +# --------------------------------------------------------------------------- +# TestAsyncConversationManager +# --------------------------------------------------------------------------- + + +class TestAsyncConversationManager: + # --- Constructor validation --- + + def test_empty_model_raises(self): + client = _make_async_client() + with pytest.raises(ValueError, match="model"): + AsyncConversationManager(client, model="", max_tokens=512) + + def test_zero_max_tokens_raises(self): + client = _make_async_client() + with pytest.raises(ValueError, match="max_tokens"): + AsyncConversationManager(client, model="claude-3", max_tokens=0) + + def test_negative_context_window_limit_raises(self): + client = _make_async_client() + with pytest.raises(ValueError, match="context_window_limit"): + AsyncConversationManager( + client, model="claude-3", max_tokens=512, context_window_limit=-5 + ) + + def test_invalid_token_budget_headroom(self): + client = _make_async_client() + with pytest.raises(ValueError, match="token_budget_headroom"): + AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + token_budget_headroom=1.0, + ) + + # --- get_response --- + + @pytest.mark.asyncio + async def test_get_response_calls_api_once(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("Hello") + client.messages.create.assert_awaited_once() + + @pytest.mark.asyncio + async def test_get_response_returns_message(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + response = await mgr.get_response("Hello") + assert response is client.messages.create.return_value + + @pytest.mark.asyncio + async def 
test_get_response_appends_assistant_turn(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("Hello") + assert len(mgr.history) == 2 + assert mgr.history[1]["role"] == "assistant" + + @pytest.mark.asyncio + async def test_multi_turn(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("First") + await mgr.get_response("Second") + assert len(mgr.history) == 4 + + @pytest.mark.asyncio + async def test_last_usage_none_initially(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + assert mgr.last_usage is None + + @pytest.mark.asyncio + async def test_last_usage_populated_after_call(self): + client = _make_async_client(input_tokens=300, output_tokens=60) + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("Hi") + assert mgr.last_usage.input_tokens == 300 + assert mgr.last_usage.output_tokens == 60 + + @pytest.mark.asyncio + async def test_kwargs_forwarded(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("Hi", temperature=0.7) + _, kwargs = client.messages.create.call_args + assert kwargs.get("temperature") == 0.7 + + @pytest.mark.asyncio + async def test_system_prompt_passed(self): + client = _make_async_client() + mgr = AsyncConversationManager( + client, model="claude-3", max_tokens=512, system="Be concise." + ) + await mgr.get_response("Hello") + _, kwargs = client.messages.create.call_args + assert kwargs.get("system") == "Be concise." 
+ + @pytest.mark.asyncio + async def test_no_staged_message_raises(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + with pytest.raises(ValueError): + await mgr.get_response() + + @pytest.mark.asyncio + async def test_history_returns_copy(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("Hi") + h = mgr.history + h.append({"role": "user", "content": "injected"}) + assert len(mgr.history) == 1 + + @pytest.mark.asyncio + async def test_reset(self): + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + await mgr.get_response("Hello") + mgr.reset() + assert mgr.history == [] + assert mgr.last_usage is None + assert mgr._model == "claude-3" + + @pytest.mark.asyncio + async def test_no_truncation_on_first_call_heuristic(self): + client = _make_async_client(input_tokens=999, output_tokens=999) + mgr = AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=100, + token_budget_headroom=0.10, + ) + await mgr.get_response("Hello") + assert len(mgr.history) == 2 + + @pytest.mark.asyncio + async def test_accurate_mode_count_tokens_called(self): + client = _make_async_client(input_tokens=50, output_tokens=25) + ct = MagicMock() + ct.input_tokens = 50 + client.messages.count_tokens = AsyncMock(return_value=ct) + + mgr = AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + accurate_token_counting=True, + ) + await mgr.get_response("Hi") + client.messages.count_tokens.assert_awaited_once() + + @pytest.mark.asyncio + async def test_truncation_drops_oldest_pair(self): + client = _make_async_client(input_tokens=800, output_tokens=100) + mgr = AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + 
token_budget_headroom=0.10, + ) + await mgr.get_response("First") # usage=900 + await mgr.get_response("Second") # estimated=900 >= 900 → drop first pair + assert len(mgr.history) == 2 + + @pytest.mark.asyncio + async def test_truncation_raises_when_cannot_truncate(self): + client = _make_async_client(input_tokens=950, output_tokens=100) + mgr = AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + mgr._last_usage = _make_usage(input_tokens=950, output_tokens=100) + mgr._history = [{"role": "user", "content": "single large message"}] + + with pytest.raises(ValueError, match="Cannot truncate further"): + await mgr.get_response() + + # --- Role-alternation guard --- + + def test_add_user_message_consecutive_raises(self): + """Adding a second user message without an assistant reply should fail.""" + client = _make_async_client() + mgr = AsyncConversationManager(client, model="claude-3", max_tokens=512) + mgr.add_user_message("first") + with pytest.raises(ValueError, match="alternation"): + mgr.add_user_message("second") + + @pytest.mark.asyncio + async def test_truncation_invariant_violation_raises(self): + """If history starts with broken alternation, truncation should raise.""" + client = _make_async_client(input_tokens=800, output_tokens=100) + mgr = AsyncConversationManager( + client, + model="claude-3", + max_tokens=512, + context_window_limit=1000, + token_budget_headroom=0.10, + ) + mgr._last_usage = _make_usage(input_tokens=800, output_tokens=100) + mgr._history = [ + {"role": "assistant", "content": "bad"}, + {"role": "user", "content": "q1"}, + {"role": "user", "content": "q2"}, + ] + with pytest.raises(ValueError, match="alternation invariant"): + await mgr.get_response()