Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-->

## [UNRELEASED]

* Added support for built-in provider tools via a new `ToolBuiltIn` class. This enables provider-specific functionality like OpenAI's image generation to be registered and used as tools. Built-in tools pass raw provider definitions directly to the API rather than wrapping Python functions. (#214)
* `ChatGoogle()` gains basic support for image generation. (#214)

## [0.14.0] - 2025-12-09

### New features
Expand Down
3 changes: 2 additions & 1 deletion chatlas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._provider_portkey import ChatPortkey
from ._provider_snowflake import ChatSnowflake
from ._tokens import token_usage
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn

try:
Expand Down Expand Up @@ -84,6 +84,7 @@
"Provider",
"token_usage",
"Tool",
"ToolBuiltIn",
"ToolRejectError",
"Turn",
"UserTurn",
Expand Down
58 changes: 42 additions & 16 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from ._mcp_manager import MCPSessionManager
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
from ._tokens import compute_cost, get_token_pricing, tokens_log
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
from ._typing_extensions import TypedDict, TypeGuard
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(
self.system_prompt = system_prompt
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}

self._tools: dict[str, Tool] = {}
self._tools: dict[str, Tool | ToolBuiltIn] = {}
self._on_tool_request_callbacks = CallbackManager()
self._on_tool_result_callbacks = CallbackManager()
self._current_display: Optional[MarkdownDisplay] = None
Expand Down Expand Up @@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):

def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | ToolBuiltIn,
*,
force: bool = False,
name: Optional[str] = None,
Expand Down Expand Up @@ -1982,23 +1982,30 @@ def add(a: int, b: int) -> int:
func.func, name=name, model=model, annotations=annotations
)
func = func.func
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
else:
if isinstance(func, ToolBuiltIn):
tool = func
else:
tool = Tool.from_func(
func, name=name, model=model, annotations=annotations
Comment thread
cpsievert marked this conversation as resolved.
)

tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
if tool.name in self._tools and not force:
raise ValueError(
f"Tool with name '{tool.name}' is already registered. "
"Set `force=True` to overwrite it."
)
self._tools[tool.name] = tool

def get_tools(self) -> list[Tool]:
def get_tools(self) -> list[Tool | ToolBuiltIn]:
"""
Get the list of registered tools.

Returns
-------
list[Tool]
A list of `Tool` instances that are currently registered with the chat.
list[Tool | ToolBuiltIn]
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
"""
return list(self._tools.values())

Expand Down Expand Up @@ -2522,7 +2529,7 @@ def _submit_turns(
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self._tools.values()):
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

def emit(text: str | Content):
Expand Down Expand Up @@ -2683,15 +2690,24 @@ def _collect_all_kwargs(

def _invoke_tool(self, request: ContentToolRequest):
tool = self._tools.get(request.name)
func = tool.func if tool is not None else None

if func is None:
if tool is None:
yield self._handle_tool_error_result(
request,
error=RuntimeError("Unknown tool."),
)
return

if isinstance(tool, ToolBuiltIn):
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

# First, invoke the request callbacks. If a ToolRejectError is raised,
# treat it like a tool failure (i.e., gracefully handle it).
result: ContentToolResult | None = None
Expand All @@ -2703,9 +2719,9 @@ def _invoke_tool(self, request: ContentToolRequest):

try:
if isinstance(request.arguments, dict):
res = func(**request.arguments)
res = tool.func(**request.arguments)
else:
res = func(request.arguments)
res = tool.func(request.arguments)

# Normalize res as a generator of results.
if not inspect.isgenerator(res):
Expand Down Expand Up @@ -2739,10 +2755,15 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
)
return

if tool._is_async:
func = tool.func
else:
func = wrap_async(tool.func)
if isinstance(tool, ToolBuiltIn):
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

# First, invoke the request callbacks. If a ToolRejectError is raised,
# treat it like a tool failure (i.e., gracefully handle it).
Expand All @@ -2753,6 +2774,11 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
yield self._handle_tool_error_result(request, e)
return

if tool._is_async:
func = tool.func
else:
func = wrap_async(tool.func)

# Invoke the tool (if it hasn't been rejected).
try:
if isinstance(request.arguments, dict):
Expand Down
26 changes: 16 additions & 10 deletions chatlas/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._typing_extensions import TypedDict

if TYPE_CHECKING:
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn


class ToolAnnotations(TypedDict, total=False):
Expand Down Expand Up @@ -104,15 +104,21 @@ class ToolInfo(BaseModel):
annotations: Optional[ToolAnnotations] = None

@classmethod
def from_tool(cls, tool: "Tool") -> "ToolInfo":
"""Create a ToolInfo from a Tool instance."""
func_schema = tool.schema["function"]
return cls(
name=tool.name,
description=func_schema.get("description", ""),
parameters=func_schema.get("parameters", {}),
annotations=tool.annotations,
)
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
from ._tools import ToolBuiltIn

if isinstance(tool, ToolBuiltIn):
return cls(name=tool.name, description=tool.name, parameters={})
else:
# For regular tools, extract from schema
func_schema = tool.schema["function"]
return cls(
name=tool.name,
description=func_schema.get("description", ""),
parameters=func_schema.get("parameters", {}),
annotations=tool.annotations,
)


ContentTypeEnum = Literal[
Expand Down
6 changes: 3 additions & 3 deletions chatlas/_mcp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional, Sequence

from ._tools import Tool
from ._tools import Tool, ToolBuiltIn

if TYPE_CHECKING:
from mcp import ClientSession
Expand All @@ -23,7 +23,7 @@ class SessionInfo(ABC):

# Primary derived attributes
session: ClientSession | None = None
tools: dict[str, Tool] = field(default_factory=dict)
tools: dict[str, Tool | ToolBuiltIn] = field(default_factory=dict)

# Background task management
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -74,7 +74,7 @@ async def request_tools(self) -> None:
tool_names = tool_names.difference(exclude)

# Apply namespace and convert to chatlas.Tool instances
self_tools: dict[str, Tool] = {}
self_tools: dict[str, Tool | ToolBuiltIn] = {}
for tool in response.tools:
if tool.name not in tool_names:
continue
Expand Down
18 changes: 9 additions & 9 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel

from ._content import Content
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn
from ._turn import AssistantTurn, Turn
from ._typing_extensions import NotRequired, TypedDict

Expand Down Expand Up @@ -162,7 +162,7 @@ def chat_perform(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -174,7 +174,7 @@ def chat_perform(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT]: ...
Expand All @@ -185,7 +185,7 @@ def chat_perform(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand All @@ -197,7 +197,7 @@ async def chat_perform_async(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -209,7 +209,7 @@ async def chat_perform_async(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT]: ...
Expand All @@ -220,7 +220,7 @@ async def chat_perform_async(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand Down Expand Up @@ -259,15 +259,15 @@ def value_tokens(
def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

@abstractmethod
async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

Expand Down
Loading
Loading