diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index ba534436d6..5f8c27b182 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -36,6 +36,7 @@ Role, TextContent, UriContent, + normalize_messages, prepend_agent_framework_to_user_agent, ) from agent_framework.observability import use_agent_instrumentation @@ -236,7 +237,7 @@ async def run_stream( Yields: An agent response item. """ - messages = self._normalize_messages(messages) + messages = normalize_messages(messages) a2a_message = self._prepare_message_for_a2a(messages[-1]) response_stream = self.client.send_message(a2a_message) diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 606e1e83b6..7dc676b06a 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -13,6 +13,7 @@ ContextProvider, Role, TextContent, + normalize_messages, ) from agent_framework._pydantic import AFBaseSettings from agent_framework.exceptions import ServiceException, ServiceInitializationError @@ -237,7 +238,7 @@ async def run( thread = self.get_new_thread() thread.service_thread_id = await self._start_new_conversation() - input_messages = self._normalize_messages(messages) + input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) @@ -278,7 +279,7 @@ async def run_stream( thread = self.get_new_thread() thread.service_thread_id = await self._start_new_conversation() - input_messages = self._normalize_messages(messages) + input_messages = normalize_messages(messages) question = "\n".join([message.text for message in input_messages]) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 628ac7fb17..253eb81b8c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -38,7 +38,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - Role, + normalize_messages, ) from .exceptions import AgentExecutionException, AgentInitializationError from .observability import use_agent_instrumentation @@ -498,21 +498,6 @@ async def agent_wrapper(**kwargs: Any) -> str: agent_tool._forward_runtime_kwargs = True # type: ignore return agent_tool - def _normalize_messages( - self, - messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, - ) -> list[ChatMessage]: - if messages is None: - return [] - - if isinstance(messages, str): - return [ChatMessage(role=Role.USER, text=messages)] - - if isinstance(messages, ChatMessage): - return [messages] - - return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages] - # region ChatAgent @@ -797,7 +782,7 @@ async def run( # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) - input_messages = self._normalize_messages(messages) + input_messages = normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( thread=thread, input_messages=input_messages, **kwargs ) @@ -925,7 +910,7 @@ async def run_stream( # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) - input_messages = self._normalize_messages(messages) + input_messages = normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( thread=thread, input_messages=input_messages, **kwargs ) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 0e26565c5a..aeafd91ac6 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar from ._serialization import SerializationMixin -from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, prepare_messages +from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, normalize_messages, prepare_messages from .exceptions import MiddlewareException if TYPE_CHECKING: @@ -1225,7 +1225,7 @@ async def middleware_enabled_run( if chat_middlewares: kwargs["middleware"] = chat_middlewares - normalized_messages = self._normalize_messages(messages) + normalized_messages = normalize_messages(messages) # Execute with middleware if available if agent_pipeline.has_middlewares: @@ -1273,7 +1273,7 @@ def middleware_enabled_run_stream( if chat_middlewares: kwargs["middleware"] = chat_middlewares - normalized_messages = self._normalize_messages(messages) + normalized_messages = normalize_messages(messages) # Execute with middleware if available if agent_pipeline.has_middlewares: diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 4ee640304e..98e0967374 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -66,6 +66,7 @@ "UsageContent", "UsageDetails", "merge_chat_options", + "normalize_messages", "normalize_tools", "prepare_function_call_results", "prepend_instructions_to_messages", @@ -2495,6 +2496,22 @@ def prepare_messages( return return_messages +def normalize_messages( + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, +) -> list[ChatMessage]: + """Normalize message inputs to a list of ChatMessage objects.""" + if messages is None: + return [] + + if isinstance(messages, str): + return [ChatMessage(role=Role.USER, text=messages)] + + if isinstance(messages, ChatMessage): + return [messages] + + return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages] + + def prepend_instructions_to_messages( messages: list[ChatMessage], instructions: str | Sequence[str] | None, diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 5cfea39287..4994cf4981 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -1902,3 +1902,59 @@ async def kwargs_middleware( assert modified_kwargs["max_tokens"] == 500 assert modified_kwargs["new_param"] == "added_by_middleware" assert modified_kwargs["custom_param"] == "test_value" # Should still be there + + +class TestMiddlewareWithProtocolOnlyAgent: + """Test use_agent_middleware with agents implementing only AgentProtocol.""" + + async def test_middleware_with_protocol_only_agent(self) -> None: + """Verify middleware works without BaseAgent inheritance for both run and run_stream.""" + from collections.abc import AsyncIterable + + from agent_framework import AgentProtocol, AgentResponse, AgentResponseUpdate, use_agent_middleware + + execution_order: list[str] = [] + + class TrackingMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("before") + await next(context) + execution_order.append("after") + + @use_agent_middleware + class ProtocolOnlyAgent: + """Minimal agent implementing only AgentProtocol, not inheriting from BaseAgent.""" + + def __init__(self): + self.id = "protocol-only-agent" + self.name = "Protocol Only Agent" + self.description = "Test agent" + self.middleware = [TrackingMiddleware()] + + async def run(self, messages=None, *, thread=None, **kwargs) -> AgentResponse: + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + def run_stream(self, messages=None, *, thread=None, **kwargs) -> AsyncIterable[AgentResponseUpdate]: + async def _stream(): + yield AgentResponseUpdate() + + return _stream() + + def get_new_thread(self, **kwargs): + return None + + agent = ProtocolOnlyAgent() + assert isinstance(agent, AgentProtocol) + + # Test run (non-streaming) + response = await agent.run("test message") + assert response is not None + assert execution_order == ["before", "after"] + + # Test run_stream (streaming) + execution_order.clear() + async for _ in agent.run_stream("test message"): + pass + assert execution_order == ["before", "after"]