Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions openhands-agent-server/openhands/agent_server/pub_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,26 @@ def unsubscribe(self, subscriber_id: UUID) -> bool:

async def __call__(self, event: T) -> None:
"""Invoke all registered callbacks with the given event.
Each callback is invoked in its own try/catch block to prevent
one failing callback from affecting others.
Subscribers are notified concurrently so a slow client cannot
block delivery to others. Each callback runs in its own
error-handling wrapper to preserve fault isolation.
Args:
event: The event to pass to all callbacks
"""
for subscriber_id, subscriber in list(self._subscribers.items()):
subscribers = list(self._subscribers.items())
if not subscribers:
return

async def _notify(subscriber_id: UUID, subscriber: Subscriber[T]):
try:
await subscriber(event)
except Exception as e:
logger.error(f"Error in subscriber {subscriber_id}: {e}", exc_info=True)
logger.error(
f"Error in subscriber {subscriber_id}: {e}",
exc_info=True,
)

await asyncio.gather(*[_notify(sid, sub) for sid, sub in subscribers])

async def close(self):
await asyncio.gather(
Expand Down
46 changes: 44 additions & 2 deletions tests/agent_server/test_pub_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,21 @@ def unsubscribe(self, subscriber_id: UUID) -> bool:

async def __call__(self, event) -> None:
"""Invoke all registered callbacks with the given event."""
for subscriber_id, subscriber in list(self._subscribers.items()):
subscribers = list(self._subscribers.items())
if not subscribers:
return

async def _notify(subscriber_id, subscriber):
try:
await subscriber(event)
except Exception as e:
self._logger.error(
f"Error in subscriber {subscriber_id}: {e}", exc_info=True
f"Error in subscriber {subscriber_id}: {e}",
exc_info=True,
)

await asyncio.gather(*[_notify(sid, sub) for sid, sub in subscribers])

async def close(self):
await asyncio.gather(
*[subscriber.close() for subscriber in self._subscribers.values()],
Expand Down Expand Up @@ -424,6 +431,41 @@ async def test_call_with_subscriber_error_isolation(
assert pubsub._logger.error_calls[0][1] is True # exc_info=True


class _TimedSubscriber(SubscriberForTesting):
"""Subscriber that records delivery wall-time after an artificial delay."""

def __init__(self, name: str, delay: float, log: list[tuple[str, float]]):
self.name = name
self.delay = delay
self.log = log

async def __call__(self, event):
start = asyncio.get_event_loop().time()
await asyncio.sleep(self.delay)
self.log.append((self.name, asyncio.get_event_loop().time() - start))


class TestPubSubConcurrentDispatch:
"""Test that __call__ dispatches to subscribers concurrently."""

@pytest.mark.asyncio
async def test_slow_subscriber_does_not_block_others(self, pubsub):
"""A slow subscriber must not delay delivery to faster ones."""
delivery_log: list[tuple[str, float]] = []

pubsub.subscribe(_TimedSubscriber("slow", 0.2, delivery_log))
pubsub.subscribe(_TimedSubscriber("fast", 0.0, delivery_log))

start = asyncio.get_event_loop().time()
await pubsub(MockEvent())
elapsed = asyncio.get_event_loop().time() - start

# Both subscribers were called
assert len(delivery_log) == 2
# Wall time ≈ 0.2s (concurrent), not ≈ 0.2s+ (sequential)
assert elapsed < 0.3


class TestPubSubEventIsolation:
"""Test cases ensuring removed subscribers don't receive events."""

Expand Down
Loading