Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import urllib.parse
from typing import Any, Union, Mapping, TypeVar
from typing_extensions import Self, override
from typing_extensions import Self, Literal, TypeAlias, override

import httpx

Expand All @@ -30,12 +30,38 @@
log: logging.Logger = logging.getLogger(__name__)

DEFAULT_VERSION = "bedrock-2023-05-31"
BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER = "X-Amzn-Bedrock-PerformanceConfig-Latency"
BedrockPerformanceConfigLatency: TypeAlias = Literal["standard", "optimized"]

_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
def _apply_performance_config_latency_header(
    options: FinalRequestOptions,
    performance_config_latency: BedrockPerformanceConfigLatency | None,
) -> None:
    """Add the Bedrock latency performance-config header to `options` in place.

    Nothing is changed when `performance_config_latency` is `None`, or when
    `options.headers` already contains the header under any casing; an
    existing entry always takes precedence over the value passed here.

    Args:
        options: Request options whose `headers` attribute may be replaced.
        performance_config_latency: Header value to set, or `None` to leave
            `options` untouched.
    """
    if performance_config_latency is None:
        return

    existing = options.headers if is_given(options.headers) else {}

    # Header names are matched case-insensitively; the target name is
    # lowercased once here rather than on every loop iteration.
    target = BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER.lower()
    if any(name.lower() == target for name in existing):
        return

    # Assign a fresh dict rather than mutating the mapping already on `options`.
    options.headers = {**existing, BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER: performance_config_latency}


def _prepare_options(
input_options: FinalRequestOptions,
*,
performance_config_latency: BedrockPerformanceConfigLatency | None = None,
) -> FinalRequestOptions:
options = model_copy(input_options, deep=True)

if is_dict(options.json_data):
Expand All @@ -57,6 +83,7 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
options.url = f"/model/{model}/invoke-with-response-stream"
else:
options.url = f"/model/{model}/invoke"
_apply_performance_config_latency_header(options, performance_config_latency)

if options.url.startswith("/v1/messages/batches"):
raise AnthropicError("The Batch API is not supported in Bedrock yet")
Expand Down Expand Up @@ -146,6 +173,7 @@ def __init__(
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
performance_config_latency: BedrockPerformanceConfigLatency | None = None,
# Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.Client | None = None,
# Enable or disable schema validation for data returned by the API.
Expand Down Expand Up @@ -182,6 +210,7 @@ def __init__(
self.aws_profile = aws_profile

self.aws_session_token = aws_session_token
self.performance_config_latency: BedrockPerformanceConfigLatency | None = performance_config_latency

if base_url is None:
base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL")
Expand Down Expand Up @@ -209,7 +238,7 @@ def _make_sse_decoder(self) -> AWSEventStreamDecoder:

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
    """Delegate to the module-level `_prepare_options` helper.

    This client's `performance_config_latency` setting is forwarded so the
    helper can take it into account when preparing the request options.
    """
    # A single return: an earlier `return _prepare_options(options)` ahead of
    # this line would make the latency setting unreachable.
    return _prepare_options(options, performance_config_latency=self.performance_config_latency)

@override
def _prepare_request(self, request: httpx.Request) -> None:
Expand Down Expand Up @@ -250,6 +279,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
performance_config_latency: BedrockPerformanceConfigLatency | None | NotGiven = NOT_GIVEN,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -285,6 +315,9 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
performance_config_latency=self.performance_config_latency
if isinstance(performance_config_latency, NotGiven)
else performance_config_latency,
**_extra_kwargs,
)

Expand All @@ -311,6 +344,7 @@ def __init__(
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
performance_config_latency: BedrockPerformanceConfigLatency | None = None,
# Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.AsyncClient | None = None,
# Enable or disable schema validation for data returned by the API.
Expand Down Expand Up @@ -347,6 +381,7 @@ def __init__(
self.aws_profile = aws_profile

self.aws_session_token = aws_session_token
self.performance_config_latency: BedrockPerformanceConfigLatency | None = performance_config_latency

if base_url is None:
base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL")
Expand Down Expand Up @@ -374,7 +409,7 @@ def _make_sse_decoder(self) -> AWSEventStreamDecoder:

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
    """Delegate to the module-level `_prepare_options` helper.

    This client's `performance_config_latency` setting is forwarded so the
    helper can take it into account when preparing the request options.
    """
    # A single return: an earlier `return _prepare_options(options)` ahead of
    # this line would make the latency setting unreachable.
    return _prepare_options(options, performance_config_latency=self.performance_config_latency)

@override
async def _prepare_request(self, request: httpx.Request) -> None:
Expand Down Expand Up @@ -415,6 +450,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
performance_config_latency: BedrockPerformanceConfigLatency | None | NotGiven = NOT_GIVEN,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -450,6 +486,9 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
performance_config_latency=self.performance_config_latency
if isinstance(performance_config_latency, NotGiven)
else performance_config_latency,
**_extra_kwargs,
)

Expand Down
103 changes: 103 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
aws_access_key="example-access-key",
aws_secret_key="example-secret-key",
)
# Sync and async clients constructed with `performance_config_latency="optimized"`.
# The region and key values are placeholder literals.
optimized_sync_client = AnthropicBedrock(
    aws_region="us-east-1",
    aws_access_key="example-access-key",
    aws_secret_key="example-secret-key",
    performance_config_latency="optimized",
)
optimized_async_client = AsyncAnthropicBedrock(
    aws_region="us-east-1",
    aws_access_key="example-access-key",
    aws_secret_key="example-secret-key",
    performance_config_latency="optimized",
)

# Name of the request header that carries the latency performance setting.
BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER = "X-Amzn-Bedrock-PerformanceConfig-Latency"


class MockRequestCall(Protocol):
Expand Down Expand Up @@ -166,6 +180,95 @@ def test_application_inference_profile(respx_mock: MockRouter) -> None:
)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_performance_config_latency_unset_by_default(respx_mock: MockRouter) -> None:
    """A client built without a latency setting sends no latency header."""
    invoke_route = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
    stub_response = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(invoke_route).mock(return_value=stub_response)

    sync_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say hello there!"}],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_headers = recorded[0].request.headers
    assert BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER not in sent_headers


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_performance_config_latency_optimized(respx_mock: MockRouter) -> None:
    """The sync client configured as "optimized" sends that value in the latency header."""
    invoke_route = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
    stub_response = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(invoke_route).mock(return_value=stub_response)

    optimized_sync_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say hello there!"}],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_headers = recorded[0].request.headers
    assert sent_headers[BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER] == "optimized"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_performance_config_latency_optimized_async(respx_mock: MockRouter) -> None:
    """The async client configured as "optimized" sends that value in the latency header."""
    invoke_route = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")
    stub_response = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(invoke_route).mock(return_value=stub_response)

    await optimized_async_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say hello there!"}],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_headers = recorded[0].request.headers
    assert sent_headers[BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER] == "optimized"


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_performance_config_latency_optimized_streaming(respx_mock: MockRouter) -> None:
    """Streaming requests hit the streaming endpoint and still carry the latency header."""
    stream_route = re.compile(
        r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke-with-response-stream"
    )
    empty_response = httpx.Response(200, content=b"")
    respx_mock.post(stream_route).mock(return_value=empty_response)

    stream = optimized_sync_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say hello there!"}],
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        stream=True,
    )
    # The stream body is never consumed; close the underlying response explicitly.
    stream.response.close()

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1

    sent_request = recorded[0].request
    expected_url = (
        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke-with-response-stream"
    )
    assert sent_request.url == expected_url
    assert sent_request.headers[BEDROCK_PERFORMANCE_CONFIG_LATENCY_HEADER] == "optimized"


def test_copy_preserves_performance_config_latency() -> None:
    """`copy()` keeps the latency setting by default and accepts an explicit `None` override."""
    # With no argument the copy inherits the source client's value.
    assert optimized_sync_client.copy().performance_config_latency == "optimized"

    # Passing `None` explicitly clears the setting on the copy.
    without_latency = optimized_sync_client.copy(performance_config_latency=None)
    assert without_latency.performance_config_latency is None


sync_api_key_client = AnthropicBedrock(
aws_region="us-east-1",
api_key="test-api-key",
Expand Down