Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
>
> The changes related to the Colang language and runtime have moved to [CHANGELOG-Colang](./CHANGELOG-Colang.md) file.

## [Unreleased]

### 🐛 Bug Fixes

- *(embeddings)* Synchronize batched embedding requests to avoid deadlocks ([#1476](https://github.com/NVIDIA-NeMo/Guardrails/issues/1476))

## [0.22.0] - 2026-05-22

### 🚀 Features
Expand Down Expand Up @@ -111,7 +117,6 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
- Restore original 2023-2026 copyright dates on moved files ([#1831](https://github.com/NVIDIA-NeMo/Guardrails/issues/1831))
- Include scripts in docker image ([#1902](https://github.com/NVIDIA-NeMo/Guardrails/issues/1902))


## [0.21.0] - 2026-03-12

### 🚀 Features
Expand Down
179 changes: 114 additions & 65 deletions nemoguardrails/embeddings/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def __init__(
# Data structures for batching embedding requests
self._req_queue: Dict[int, str] = {}
self._req_results: Dict[int, List[float]] = {}
self._req_errors: Dict[int, BaseException] = {}
self._req_idx: int = 0
self._current_batch_finished_event: Optional[asyncio.Event] = None
self._current_batch_full_event: Optional[asyncio.Event] = None
# Stored so callers can cancel or inspect the active batch task (e.g. shutdown).
self._current_batch_task: Optional[asyncio.Task] = None
self._current_batch_submitted: asyncio.Event = asyncio.Event()
self._batch_lock: asyncio.Lock = asyncio.Lock()

# Initialize the batching configuration
self.use_batching = use_batching
Expand Down Expand Up @@ -187,74 +191,119 @@ async def build(self):
self._index.add_item(i, self._embeddings[i])
self._index.build(10)

async def _run_batch(self):
async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if self._current_batch_full_event is None or self._current_batch_finished_event is None:
raise RuntimeError("Batch events not initialized. This should not happen.")

done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
asyncio.create_task(self._current_batch_full_event.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()

# Reset the batch event
batch_event: asyncio.Event = self._current_batch_finished_event
self._current_batch_finished_event = None

# Create the actual batch to be send for computing
batch = []
batch_ids = list(self._req_queue.keys())
for req_id in batch_ids:
batch.append(self._req_queue[req_id])

# Empty the queue up to this point
self._req_queue = {}

# We allow other batches to start
self._current_batch_submitted.set()

# print(f"Running batch of length {len(batch)}")

# Compute the embeddings
embeddings = await self._get_embeddings(batch)
for i in range(len(embeddings)):
self._req_results[batch_ids[i]] = embeddings[i]

# Signal that the batch has finished processing
batch_event.set()
# Initialised here so the except handlers can safely iterate even if
# cancellation fires before the lock section populates batch_ids.
batch_ids: List[int] = []
try:
# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
_, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
asyncio.create_task(batch_full_event.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()

async with self._batch_lock:
# Reset the active batch only if it has not already rolled over.
if self._current_batch_finished_event is batch_finished_event:
self._current_batch_finished_event = None
self._current_batch_full_event = None

# Create the actual batch to be sent for computing.
batch_ids = list(self._req_queue.keys())
batch = [self._req_queue[req_id] for req_id in batch_ids]

# Empty the queue up to this point.
self._req_queue = {}

# We allow other batches to start.
self._current_batch_submitted.set()

embeddings = await self._get_embeddings(batch)
if len(embeddings) < len(batch_ids):
shortage_exc = RuntimeError(
f"Embedding model returned {len(embeddings)} embeddings for {len(batch_ids)} inputs."
)
for req_id in batch_ids[len(embeddings) :]:
self._req_errors[req_id] = shortage_exc
for i in range(len(embeddings)):
self._req_results[batch_ids[i]] = embeddings[i]
except asyncio.CancelledError as exc:
# If cancelled before the lock section, batch_ids is still [].
# Snapshot and drain the queue without the lock — safe because asyncio
# is single-threaded and no await occurs between here and `raise`.
if not batch_ids:
batch_ids = list(self._req_queue.keys())
self._req_queue = {}
if self._current_batch_finished_event is batch_finished_event:
self._current_batch_finished_event = None
self._current_batch_full_event = None
for req_id in batch_ids:
if req_id not in self._req_results and req_id not in self._req_errors:
self._req_errors[req_id] = exc
raise
except Exception as exc:
for req_id in batch_ids:
if req_id not in self._req_results and req_id not in self._req_errors:
self._req_errors[req_id] = exc
finally:
# Unconditionally unblock full-queue waiters in case the lock section
# was never reached (early cancellation).
self._current_batch_submitted.set()
batch_finished_event.set()
Comment thread
ppradyoth marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment on lines +236 to +258
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Early cancellation leaves callers with KeyError and blocks full-queue waiters

batch_ids is initialised as [] on line 198 and is only populated inside the async with self._batch_lock section (line 218). If the outer task is cancelled while awaiting asyncio.wait([sleep, batch_full_event.wait()]) — e.g. via _current_batch_task.cancel() during shutdown before the hold timer expires — CancelledError is thrown before the lock section ever runs. The except asyncio.CancelledError handler then iterates an empty batch_ids, so no error is stored for any of the requests that are still in _req_queue. finally sets batch_finished_event, those callers wake up, find nothing in _req_results or _req_errors, and crash with KeyError from self._req_results.pop(req_id).

A second failure follows: _current_batch_submitted.set() is only called inside the lock section. Any caller blocked on await batch_submitted_event.wait() because the queue was at capacity will deadlock — nothing will ever set that event again, and the new-batch creation path is also stuck because _current_batch_finished_event and _current_batch_full_event are left pointing to the already-signalled, stale events from the cancelled batch.

The existing test test_batch_get_embeddings_propagates_cancelled_batch_task avoids this path by waiting for encoding_started before cancelling, guaranteeing the lock section has already run. A minimal fix is to handle remaining queue entries in the except handlers (safe without the lock because asyncio is single-threaded between await points) and unconditionally call self._current_batch_submitted.set() in the finally block.

Prompt To Fix With AI
This is a comment left during a code review.
Path: nemoguardrails/embeddings/basic.py
Line: 236-246

Comment:
**Early cancellation leaves callers with `KeyError` and blocks full-queue waiters**

`batch_ids` is initialised as `[]` on line 198 and is only populated inside the `async with self._batch_lock` section (line 218). If the outer task is cancelled while awaiting `asyncio.wait([sleep, batch_full_event.wait()])` — e.g. via `_current_batch_task.cancel()` during shutdown before the hold timer expires — `CancelledError` is thrown before the lock section ever runs. The `except asyncio.CancelledError` handler then iterates an empty `batch_ids`, so no error is stored for any of the requests that are still in `_req_queue`. `finally` sets `batch_finished_event`, those callers wake up, find nothing in `_req_results` or `_req_errors`, and crash with `KeyError` from `self._req_results.pop(req_id)`.

A second failure follows: `_current_batch_submitted.set()` is only called inside the lock section. Any caller blocked on `await batch_submitted_event.wait()` because the queue was at capacity will deadlock — nothing will ever set that event again, and the new-batch creation path is also stuck because `_current_batch_finished_event` and `_current_batch_full_event` are left pointing to the already-signalled, stale events from the cancelled batch.

The existing test `test_batch_get_embeddings_propagates_cancelled_batch_task` avoids this path by waiting for `encoding_started` before cancelling, guaranteeing the lock section has already run. A minimal fix is to handle remaining queue entries in the except handlers (safe without the lock because asyncio is single-threaded between `await` points) and unconditionally call `self._current_batch_submitted.set()` in the `finally` block.

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines +236 to +258
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 KeyError (not CancelledError) when task is cancelled before lock acquisition

batch_ids is initialised to [] before the try block. If _current_batch_task.cancel() fires while _run_batch is suspended inside asyncio.wait(...), or while it is waiting to acquire _batch_lock (another _batch_get_embeddings call holds it briefly), the except asyncio.CancelledError handler iterates over the empty list and writes nothing to _req_errors. batch_finished_event.set() still fires in finally, waking every caller. Those callers then find no entry in _req_errors and fall through to self._req_results.pop(req_id), which raises KeyError.

A secondary consequence of the same scenario: the lock section never runs, so _current_batch_submitted is never set and _req_queue is never cleared. Any overflow caller blocked on await batch_submitted_event.wait() is permanently stuck, and future callers joining the still-non-None _current_batch_finished_event would also hit the same KeyError.

The fix requires either (a) snapshotting _req_queue keys before entering asyncio.wait so they are available to the early-cancel path, or (b) guarding self._req_results.pop(req_id) in _batch_get_embeddings to raise a typed error when the key is absent after the event fires.

Prompt To Fix With AI
This is a comment left during a code review.
Path: nemoguardrails/embeddings/basic.py
Line: 236-246

Comment:
**`KeyError` (not `CancelledError`) when task is cancelled before lock acquisition**

`batch_ids` is initialised to `[]` before the `try` block. If `_current_batch_task.cancel()` fires while `_run_batch` is suspended inside `asyncio.wait(...)`, or while it is waiting to acquire `_batch_lock` (another `_batch_get_embeddings` call holds it briefly), the `except asyncio.CancelledError` handler iterates over the empty list and writes nothing to `_req_errors`. `batch_finished_event.set()` still fires in `finally`, waking every caller. Those callers then find no entry in `_req_errors` and fall through to `self._req_results.pop(req_id)`, which raises `KeyError`.

A secondary consequence of the same scenario: the lock section never runs, so `_current_batch_submitted` is never set and `_req_queue` is never cleared. Any overflow caller blocked on `await batch_submitted_event.wait()` is permanently stuck, and future callers joining the still-non-None `_current_batch_finished_event` would also hit the same `KeyError`.

The fix requires either (a) snapshotting `_req_queue` keys before entering `asyncio.wait` so they are available to the early-cancel path, or (b) guarding `self._req_results.pop(req_id)` in `_batch_get_embeddings` to raise a typed error when the key is absent after the event fires.

How can I resolve this? If you propose a fix, please make it concise.


async def _batch_get_embeddings(self, text: str) -> List[float]:
# As long as the queue is full, we wait for the next batch
while len(self._req_queue) >= self.max_batch_size:
await self._current_batch_submitted.wait()

req_id = self._req_idx
self._req_idx += 1
self._req_queue[req_id] = text

if self._current_batch_finished_event is None or self._current_batch_full_event is None:
self._current_batch_finished_event = asyncio.Event()
self._current_batch_full_event = asyncio.Event()
self._current_batch_submitted.clear()
asyncio.ensure_future(self._run_batch())

# We check if we reached the max batch size
if len(self._req_queue) >= self.max_batch_size:
self._current_batch_full_event.set()

# Wait for the batch to finish
await self._current_batch_finished_event.wait()

# Remove the result and return it
result = self._req_results[req_id]
del self._req_results[req_id]
while True:
async with self._batch_lock:
if len(self._req_queue) < self.max_batch_size:
req_id = self._req_idx
self._req_idx += 1
self._req_queue[req_id] = text

if self._current_batch_finished_event is None or self._current_batch_full_event is None:
self._current_batch_finished_event = asyncio.Event()
self._current_batch_full_event = asyncio.Event()
self._current_batch_submitted.clear()
self._current_batch_task = asyncio.create_task(
self._run_batch(
self._current_batch_full_event,
self._current_batch_finished_event,
)
)

batch_finished_event = self._current_batch_finished_event
batch_full_event = self._current_batch_full_event
if batch_finished_event is None or batch_full_event is None:
raise RuntimeError("Batch events not initialized. This should not happen.")

# We check if we reached the max batch size
if len(self._req_queue) >= self.max_batch_size:
batch_full_event.set()

break

batch_submitted_event = self._current_batch_submitted

await batch_submitted_event.wait()
Comment on lines +261 to +292
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-positive max_batch_size.

With max_batch_size <= 0, the condition on Line 251 is never true, so no batch task is ever created and every caller blocks forever on Line 280. This should fail fast with a ValueError during initialization or before entering the loop.

Minimal guard
         self.use_batching = use_batching
+        if max_batch_size <= 0:
+            raise ValueError("max_batch_size must be greater than 0")
         self.max_batch_size = max_batch_size
         self.max_batch_hold = max_batch_hold
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@nemoguardrails/embeddings/basic.py` around lines 249 - 280, The code can
deadlock when max_batch_size <= 0 because the enqueue loop (uses
self._batch_lock, self._req_queue, self._current_batch_task, and calls
self._run_batch) never creates a batch; add a guard that validates
max_batch_size is a positive integer and raise ValueError early (either in the
class __init__ or at the start of the enqueue method that contains this loop) so
callers fail fast; ensure the check references self.max_batch_size and prevents
entering the while True block when invalid.


# Wait for the batch to finish; clean up our slot regardless of how we exit.
try:
await batch_finished_event.wait()

if req_id in self._req_errors:
raise self._req_errors.pop(req_id)

if req_id not in self._req_results:
raise RuntimeError(f"Batch completed without a result for request {req_id}.")
result = self._req_results.pop(req_id)
finally:
self._req_results.pop(req_id, None)
self._req_errors.pop(req_id, None)

return result

Expand Down
125 changes: 125 additions & 0 deletions tests/test_batch_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,131 @@
from nemoguardrails.embeddings.index import IndexItem


class MockEmbeddingModel:
    """Embedding model stub that counts calls and echoes each text's trailing number."""

    def __init__(self):
        # Incremented on every `encode_async` call.
        self.call_count = 0

    async def encode_async(self, texts):
        """Return one single-element embedding per input text.

        The last whitespace-separated token of each text is parsed as a float
        and used as the embedding value (e.g. ``"text 3"`` -> ``[3.0]``).

        Raises:
            ValueError: If a text's last token is not a valid float.
            IndexError: If a text is empty or contains only whitespace.
        """
        self.call_count += 1
        # Short await so the call actually suspends before returning.
        await asyncio.sleep(0.01)
        return [[float(text.split()[-1])] for text in texts]


class FailingEmbeddingModel:
    """Embedding model stub whose `encode_async` always fails."""

    async def encode_async(self, texts):
        """Raise ``RuntimeError("embedding model failure")`` for any input."""
        raise RuntimeError("embedding model failure")


class ShortEmbeddingModel:
    """Returns one fewer embedding than requested."""

    async def encode_async(self, texts):
        """Return ``len(texts) - 1`` single-element embeddings ``[0.0], [1.0], ...``.

        For an empty or single-element input the result is an empty list.
        """
        return [[float(i)] for i in range(len(texts) - 1)]


@pytest.mark.asyncio
async def test_batch_get_embeddings_propagates_short_result():
    """A model returning too few embeddings must surface a RuntimeError to callers."""
    embeddings_index = BasicEmbeddingsIndex(
        use_batching=True,
        max_batch_size=4,
        max_batch_hold=0.01,
    )
    embeddings_index._model = ShortEmbeddingModel()

    # The timeout turns a hang into a test failure instead of a stuck run.
    with pytest.raises(RuntimeError, match="Embedding model returned"):
        await asyncio.wait_for(
            asyncio.gather(
                embeddings_index._batch_get_embeddings("text 0"),
                embeddings_index._batch_get_embeddings("text 1"),
            ),
            timeout=1,
        )


@pytest.mark.asyncio
async def test_batch_get_embeddings_propagates_cancelled_batch_task():
    """Cancelling the active _run_batch task must wake callers with CancelledError."""
    encoding_started = asyncio.Event()

    class HangingModel:
        """Model stub that signals when encoding begins, then blocks for 10 seconds."""

        async def encode_async(self, texts):
            encoding_started.set()
            await asyncio.sleep(10)

    embeddings_index = BasicEmbeddingsIndex(
        use_batching=True,
        max_batch_size=10,
        max_batch_hold=0.001,
    )
    embeddings_index._model = HangingModel()

    caller = asyncio.create_task(embeddings_index._batch_get_embeddings("text 0"))
    # Wait until _get_embeddings has been entered so cancellation hits that await.
    await asyncio.wait_for(encoding_started.wait(), timeout=1)
    embeddings_index._current_batch_task.cancel()

    with pytest.raises(asyncio.CancelledError):
        await caller


@pytest.mark.asyncio
async def test_batch_get_embeddings_propagates_early_cancellation():
    """Cancelling _run_batch before batch_ids is populated must still wake callers."""
    embeddings_index = BasicEmbeddingsIndex(
        use_batching=True,
        max_batch_size=10,
        max_batch_hold=10,  # Long hold so cancellation fires during asyncio.wait
    )
    embeddings_index._model = MockEmbeddingModel()

    caller = asyncio.create_task(embeddings_index._batch_get_embeddings("text 0"))
    # Yield to let _batch_get_embeddings register the request and start _run_batch,
    # then cancel before max_batch_hold elapses (i.e. before the lock section runs).
    # Two yields: the first lets `caller` run and schedule the batch task, the
    # second lets the batch task itself take its first step.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert embeddings_index._current_batch_task is not None
    embeddings_index._current_batch_task.cancel()

    # The timeout turns a hang into a test failure instead of a stuck run.
    with pytest.raises(asyncio.CancelledError):
        await asyncio.wait_for(caller, timeout=1)


@pytest.mark.asyncio
async def test_batch_get_embeddings_propagates_model_error():
    """An exception raised by the embedding model must reach the awaiting caller."""
    embeddings_index = BasicEmbeddingsIndex(
        use_batching=True,
        max_batch_size=2,
        max_batch_hold=0.01,
    )
    embeddings_index._model = FailingEmbeddingModel()

    # The timeout turns a hang into a test failure instead of a stuck run.
    with pytest.raises(RuntimeError, match="embedding model failure"):
        await asyncio.wait_for(
            embeddings_index._batch_get_embeddings("text 0"),
            timeout=1,
        )


@pytest.mark.asyncio
async def test_batch_get_embeddings_handles_concurrent_batches():
    """More concurrent requests than one batch holds must all complete with correct results."""
    mock_model = MockEmbeddingModel()
    embeddings_index = BasicEmbeddingsIndex(
        use_batching=True,
        max_batch_size=2,
        max_batch_hold=0.01,
    )
    embeddings_index._model = mock_model

    # The timeout turns a deadlock into a test failure instead of a stuck run.
    results = await asyncio.wait_for(
        asyncio.gather(*(embeddings_index._batch_get_embeddings(f"text {i}") for i in range(5))),
        timeout=1,
    )

    assert sorted(result[0] for result in results) == [0, 1, 2, 3, 4]
    # 5 requests with max_batch_size=2 must produce at least 3 separate batches
    assert mock_model.call_count >= 3


@pytest.mark.skip(reason="Run manually.")
@pytest.mark.asyncio
async def test_search_speed():
Expand Down
Loading