diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 8f477f9223..9b28882836 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -509,16 +509,27 @@ async def agent_wrapper(**kwargs: Any) -> str: # Extract the input from kwargs using the specified arg_name input_text = kwargs.get(arg_name, "") - # Forward runtime context kwargs, excluding arg_name and conversation_id. + # Extract conversation_id forwarded from parent agent's tool invocation loop + parent_conversation_id = kwargs.get("conversation_id") + + # Forward runtime context kwargs, excluding arg_name, conversation_id, and options. forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")} + # Pass parent's conversation_id via additional_function_arguments so it reaches + # the sub-agent's tools through **kwargs for correlation purposes. + # We do NOT pass it via chat options because the sub-agent's chat client should + # start its own conversation, not try to continue the parent's. + run_options: dict[str, Any] | None = None + if parent_conversation_id: + run_options = {"additional_function_arguments": {"parent_conversation_id": parent_conversation_id}} + if stream_callback is None: # Use non-streaming mode - return (await self.run(input_text, stream=False, **forwarded_kwargs)).text + return (await self.run(input_text, stream=False, options=run_options, **forwarded_kwargs)).text # Use streaming mode - accumulate updates and create final response response_updates: list[AgentResponseUpdate] = [] - async for update in self.run(input_text, stream=True, **forwarded_kwargs): + async for update in self.run(input_text, stream=True, options=run_options, **forwarded_kwargs): response_updates.append(update) if is_async_callback: await stream_callback(update) # type: ignore[misc] diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 303699572c..bd77858792 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -972,7 +972,10 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: # Recursively build field definitions for the nested model nested_field_definitions: dict[str, Any] = {} - for nested_prop_name, nested_prop_details in nested_properties.items(): + for ( + nested_prop_name, + nested_prop_details, + ) in nested_properties.items(): nested_prop_details = ( json.loads(nested_prop_details) if isinstance(nested_prop_details, str) @@ -1417,11 +1420,11 @@ async def _auto_invoke_function( parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {}) # Filter out internal framework kwargs before passing to tools. - # conversation_id is an internal tracking ID that should not be forwarded to tools. + # conversation_id is forwarded so agent-as-tool wrappers can correlate sub-agent conversations. runtime_kwargs: dict[str, Any] = { key: value for key, value in (custom_args or {}).items() - if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} + if key not in {"_function_middleware_pipeline", "middleware"} } try: if not tool._schema_supplied and tool.input_model is not None: @@ -2100,7 +2103,7 @@ def get_response( max_errors: int = self.function_invocation_configuration["max_consecutive_errors_per_request"] # type: ignore[assignment] additional_function_arguments: dict[str, Any] = {} if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] - additional_function_arguments = additional_opts # type: ignore + additional_function_arguments = additional_opts # type: ignore[assignment] execute_function_calls = partial( _execute_function_calls, custom_args=additional_function_arguments, @@ -2162,6 +2165,7 @@ async def _get_response() -> ChatResponse: if response.conversation_id is not None: _update_conversation_id(kwargs, response.conversation_id, mutable_options) + additional_function_arguments["conversation_id"] = response.conversation_id prepped_messages = [] result = await _process_function_requests( @@ -2296,6 +2300,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: if response.conversation_id is not None: _update_conversation_id(kwargs, response.conversation_id, mutable_options) + additional_function_arguments["conversation_id"] = response.conversation_id prepped_messages = [] result = await _process_function_requests( diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index a857682fe2..37f3944c32 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import contextlib -from collections.abc import AsyncIterable, MutableSequence +from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence from typing import Any from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -23,10 +23,14 @@ Message, SupportsAgentRun, SupportsChatGetResponse, + agent_middleware, tool, ) from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool +from agent_framework._middleware import AgentContext + +from .conftest import MockChatClient def test_agent_session_type(agent_session: AgentSession) -> None: @@ -707,6 +711,51 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}" +async def test_chat_agent_as_tool_propagates_conversation_id(client: SupportsChatGetResponse) -> None: + """Test that as_tool passes parent_conversation_id to sub-agent via additional_function_arguments.""" + mock_client: MockChatClient = client # type: ignore[assignment] + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + mock_client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware]) + t = sub_agent.as_tool(name="delegate", arg_name="task") + + await t.invoke(arguments=t.input_model(task="Test delegation"), conversation_id="conv-parent-123") + + additional_args = captured_options.get("additional_function_arguments", {}) + assert additional_args.get("parent_conversation_id") == "conv-parent-123" + + +async def test_chat_agent_as_tool_no_conversation_id_when_absent(client: SupportsChatGetResponse) -> None: + """Test that as_tool does not inject additional_function_arguments when no conversation_id provided.""" + mock_client: MockChatClient = client # type: ignore[assignment] + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + mock_client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent(client=mock_client, name="sub_agent", middleware=[capture_middleware]) + t = sub_agent.as_tool(name="delegate", arg_name="task") + + await t.invoke(arguments=t.input_model(task="Test delegation"), user_id="user-789") + + assert "additional_function_arguments" not in captured_options + + async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP") diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index da8e907c40..e059d78560 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -339,3 +339,63 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Verify other kwargs were still forwarded assert captured_kwargs.get("api_token") == "secret-xyz-123" assert captured_kwargs.get("user_id") == "user-456" + + async def test_as_tool_propagates_conversation_id_via_options(self, client: MockChatClient) -> None: + """Test that parent_conversation_id from parent is passed to sub-agent's additional_function_arguments.""" + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + # Setup mock response + client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent( + client=client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool(name="delegate", arg_name="task") + + await tool.invoke( + arguments=tool.input_model(task="Test delegation"), + conversation_id="conv-parent-123", + ) + + # Verify parent_conversation_id was passed via additional_function_arguments + additional_args = captured_options.get("additional_function_arguments", {}) + assert additional_args.get("parent_conversation_id") == "conv-parent-123" + + async def test_as_tool_no_conversation_id_when_absent(self, client: MockChatClient) -> None: + """Test that no parent_conversation_id is injected when parent has none.""" + captured_options: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + captured_options.update(context.options or {}) + await call_next() + + client.responses = [ + ChatResponse(messages=[Message(role="assistant", text="Sub-agent response")]), + ] + + sub_agent = Agent( + client=client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool(name="delegate", arg_name="task") + + await tool.invoke( + arguments=tool.input_model(task="Test delegation"), + user_id="user-789", + ) + + # additional_function_arguments should not be in options when no conversation_id provided + assert "additional_function_arguments" not in captured_options diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 8d74dc181d..9440f964ab 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1464,4 +1464,66 @@ def test_nested_object_with_const_and_enum(): model(config={"type": "production", "level": "critical"}) -# endregion +async def test_auto_invoke_function_forwards_conversation_id() -> None: + """Test that _auto_invoke_function forwards conversation_id to tools that accept **kwargs.""" + from agent_framework._tools import _auto_invoke_function + + captured_kwargs: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capturing_tool(query: str, **kwargs: Any) -> str: + """A tool that captures kwargs.""" + captured_kwargs.update(kwargs) + return "ok" + + function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1") + + await _auto_invoke_function( + function_call, + custom_args={"conversation_id": "conv-123", "other_arg": "value"}, + config={ + "enabled": True, + "max_iterations": 1, + "max_consecutive_errors_per_request": 3, + "include_detailed_errors": True, + }, + tool_map={"capturing_tool": capturing_tool}, + ) + + assert captured_kwargs.get("conversation_id") == "conv-123" + assert captured_kwargs.get("other_arg") == "value" + + +async def test_auto_invoke_function_still_filters_internal_kwargs() -> None: + """Test that _auto_invoke_function still filters _function_middleware_pipeline and middleware.""" + from agent_framework._tools import _auto_invoke_function + + captured_kwargs: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capturing_tool(query: str, **kwargs: Any) -> str: + """A tool that captures kwargs.""" + captured_kwargs.update(kwargs) + return "ok" + + function_call = Content.from_function_call(name="capturing_tool", arguments='{"query": "test"}', call_id="call_1") + + await _auto_invoke_function( + function_call, + custom_args={ + "_function_middleware_pipeline": "should_be_filtered", + "middleware": "should_be_filtered", + "conversation_id": "conv-456", + }, + config={ + "enabled": True, + "max_iterations": 1, + "max_consecutive_errors_per_request": 3, + "include_detailed_errors": True, + }, + tool_map={"capturing_tool": capturing_tool}, + ) + + assert "_function_middleware_pipeline" not in captured_kwargs + assert "middleware" not in captured_kwargs + assert captured_kwargs.get("conversation_id") == "conv-456" diff --git a/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py b/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py new file mode 100644 index 0000000000..54b69740f3 --- /dev/null +++ b/python/samples/02-agents/tools/agent_as_tool_conversation_id_propagation.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +from agent_framework import Agent, FunctionInvocationContext, agent_middleware, tool +from agent_framework._middleware import AgentContext +from agent_framework.openai import OpenAIResponsesClient + +""" +Agent-as-Tool with Conversation ID Propagation Example + +Demonstrates how a parent agent's conversation_id is automatically propagated +to sub-agents wrapped as tools via as_tool(). This enables correlating +multi-agent conversations in storage systems. + +The middleware below is ONLY for observability — to print the conversation_id +at each stage so you can verify the propagation. It is NOT required for the +feature to work. The propagation happens automatically at the framework level. + +NOTE: conversation_id propagation requires a chat client that returns +conversation_id in its responses (e.g., OpenAI Responses API). +""" + + +# --- Observability middleware (not required for the feature) --- + + +@agent_middleware +async def log_conversation_id(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + """Prints the conversation_id seen by each agent. Only for demonstration.""" + agent_name = context.agent.name if hasattr(context.agent, "name") else "unknown" + conv_id = (context.options or {}).get("conversation_id") + additional = (context.options or {}).get("additional_function_arguments", {}) + parent_id = additional.get("parent_conversation_id") + print(f" [{agent_name}] conversation_id={conv_id}, parent_conversation_id={parent_id}") + await call_next() + + +async def log_tool_kwargs( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], +) -> None: + """Prints the kwargs forwarded to a tool. Only for demonstration.""" + conv_id = context.kwargs.get("conversation_id") + parent_id = context.kwargs.get("parent_conversation_id") + print(f" [tool:{context.function.name}] conversation_id={conv_id}, parent_conversation_id={parent_id}") + await call_next() + + +# --- Application code --- + + +# This tool is NOT required for conversation_id propagation to work. +# It is included only to show the parent_conversation_id arriving via **kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. +@tool(approval_mode="never_require") +def lookup_info(query: str, **kwargs: Any) -> str: + """Look up information for a given query. + + Args: + query: The search query. + + Keyword Args: + kwargs: Runtime context forwarded by the framework, including + parent_conversation_id if the parent agent propagated one. + + Returns: + The lookup result. + """ + parent_id = kwargs.get("parent_conversation_id") + return f"Results for '{query}' (tracked under parent conversation {parent_id})" + + +async def main() -> None: + print("=== Agent-as-Tool: Conversation ID Propagation ===\n") + + client = OpenAIResponsesClient() + + # Create a specialized research agent + researcher = Agent( + client=client, + name="ResearchAgent", + instructions="You are a research assistant. Use the lookup_info tool to find information.", + tools=[lookup_info], + middleware=[log_conversation_id], + function_middleware=[log_tool_kwargs], + ) + + # Wrap the research agent as a tool for the coordinator + research_tool = researcher.as_tool( + name="research", + description="Delegate research tasks to a specialized research agent", + arg_name="task", + arg_description="The research task to perform", + ) + + # Create coordinator with the same observability middleware + coordinator = Agent( + client=client, + name="CoordinatorAgent", + instructions=( + "You are a coordinator. When the user asks a question, delegate to the research tool to find the answer." + ), + tools=[research_tool], + middleware=[log_conversation_id], + function_middleware=[log_tool_kwargs], + ) + + # Run — watch the printed output to see conversation_id flow: + # 1. CoordinatorAgent gets a conversation_id from the API + # 2. The tool invocation forwards it to the research tool's **kwargs + # 3. ResearchAgent receives it as parent_conversation_id in its options + # 4. ResearchAgent's own tools see it via **kwargs + response = await coordinator.run("What are the latest developments in quantum computing?") + + print(f"\nCoordinator: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main())