diff --git a/openhands-agent-server/openhands/agent_server/pub_sub.py b/openhands-agent-server/openhands/agent_server/pub_sub.py index 7d2be6abc7..5522ed4f2a 100644 --- a/openhands-agent-server/openhands/agent_server/pub_sub.py +++ b/openhands-agent-server/openhands/agent_server/pub_sub.py @@ -62,16 +62,26 @@ def unsubscribe(self, subscriber_id: UUID) -> bool: async def __call__(self, event: T) -> None: """Invoke all registered callbacks with the given event. - Each callback is invoked in its own try/catch block to prevent - one failing callback from affecting others. + Subscribers are notified concurrently so a slow client cannot + block delivery to others. Each callback runs in its own + error-handling wrapper to preserve fault isolation. Args: event: The event to pass to all callbacks """ - for subscriber_id, subscriber in list(self._subscribers.items()): + subscribers = list(self._subscribers.items()) + if not subscribers: + return + + async def _notify(subscriber_id: UUID, subscriber: Subscriber[T]): try: await subscriber(event) except Exception as e: - logger.error(f"Error in subscriber {subscriber_id}: {e}", exc_info=True) + logger.error( + f"Error in subscriber {subscriber_id}: {e}", + exc_info=True, + ) + + await asyncio.gather(*[_notify(sid, sub) for sid, sub in subscribers]) async def close(self): await asyncio.gather( diff --git a/tests/agent_server/test_pub_sub.py b/tests/agent_server/test_pub_sub.py index c4d0e87378..e5bffcbcc7 100644 --- a/tests/agent_server/test_pub_sub.py +++ b/tests/agent_server/test_pub_sub.py @@ -84,14 +84,21 @@ def unsubscribe(self, subscriber_id: UUID) -> bool: async def __call__(self, event) -> None: """Invoke all registered callbacks with the given event.""" - for subscriber_id, subscriber in list(self._subscribers.items()): + subscribers = list(self._subscribers.items()) + if not subscribers: + return + + async def _notify(subscriber_id, subscriber): try: await subscriber(event) except Exception as e: self._logger.error( - f"Error in subscriber {subscriber_id}: {e}", exc_info=True + f"Error in subscriber {subscriber_id}: {e}", + exc_info=True, ) + await asyncio.gather(*[_notify(sid, sub) for sid, sub in subscribers]) + async def close(self): await asyncio.gather( *[subscriber.close() for subscriber in self._subscribers.values()], @@ -424,6 +431,41 @@ async def test_call_with_subscriber_error_isolation( assert pubsub._logger.error_calls[0][1] is True # exc_info=True +class _TimedSubscriber(SubscriberForTesting): + """Subscriber that records delivery wall-time after an artificial delay.""" + + def __init__(self, name: str, delay: float, log: list[tuple[str, float]]): + self.name = name + self.delay = delay + self.log = log + + async def __call__(self, event): + start = asyncio.get_event_loop().time() + await asyncio.sleep(self.delay) + self.log.append((self.name, asyncio.get_event_loop().time() - start)) + + +class TestPubSubConcurrentDispatch: + """Test that __call__ dispatches to subscribers concurrently.""" + + @pytest.mark.asyncio + async def test_slow_subscriber_does_not_block_others(self, pubsub): + """A slow subscriber must not delay delivery to faster ones.""" + delivery_log: list[tuple[str, float]] = [] + + pubsub.subscribe(_TimedSubscriber("slow", 0.2, delivery_log)) + pubsub.subscribe(_TimedSubscriber("fast", 0.0, delivery_log)) + + start = asyncio.get_event_loop().time() + await pubsub(MockEvent()) + elapsed = asyncio.get_event_loop().time() - start + + # Both subscribers were called + assert len(delivery_log) == 2 + # Wall time ≈ 0.2s (concurrent), not ≈ 0.2s+ (sequential) + assert elapsed < 0.3 + + class TestPubSubEventIsolation: """Test cases ensuring removed subscribers don't receive events."""