fix: synchronize embedding batches #1921
📝 Walkthrough
Refactors BasicEmbeddingsIndex batching to add per-request error propagation, an async lock and batch-task coordination, event-driven batch submission, tests for concurrency/error cases, and a changelog entry documenting the deadlock fix.
Changes: Embeddings batching concurrency and deadlock fix
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
nemoguardrails/embeddings/basic.py (1)
197-197: 💤 Low value. Unused variable `done`.
The `done` variable from `asyncio.wait` is never used. Prefix with underscore to indicate intentional discard.
- done, pending = await asyncio.wait(
+ _, pending = await asyncio.wait(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@nemoguardrails/embeddings/basic.py` at line 197, The variable tuple unpacking from asyncio.wait currently assigns to done and pending but done is never used; change the unpack to _done, pending (or prefix done with an underscore) where the call to asyncio.wait occurs so the unused result is explicitly discarded (look for the asyncio.wait call in embeddings/basic.py and the surrounding coroutine function).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@nemoguardrails/embeddings/basic.py`:
- Around line 223-231: The try block that processes embeddings from
_get_embeddings can leave some batch_ids without any entry when the returned
embeddings list is shorter than batch_ids; add a defensive length check after
awaiting self._get_embeddings(batch) in the same block: if len(embeddings) <
len(batch_ids) then for the missing ids (batch_ids[len(embeddings):]) populate
self._req_errors[req_id] with a clear exception (e.g., ValueError or
RuntimeError) describing "embedding model returned fewer items than requested"
before continuing to set any available results into self._req_results; keep the
existing exception handler and ensure batch_finished_event.set() still runs in
finally.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 1b69d4ac-283d-43c7-b47e-714a05be1039
📒 Files selected for processing (3)
- CHANGELOG.md
- nemoguardrails/embeddings/basic.py
- tests/test_batch_embeddings.py
27043bb to b14c2a5
Greptile SummaryThis PR rewrites the batched embedding concurrency model in
| Filename | Overview |
|---|---|
| nemoguardrails/embeddings/basic.py | Core batching logic overhauled: added _batch_lock, per-request error propagation, explicit cancellation handling, and event-passing to _run_batch. Logic is correct for the happy path and the common cancellation cases; two minor concerns around inner task leaks on Python < 3.12 and shared exception instance traceback mutation remain. |
| tests/test_batch_embeddings.py | Adds five async tests covering model failure, short result, early/late cancellation propagation, and concurrent batch dispatch. Coverage is well-targeted at the newly introduced code paths. |
| CHANGELOG.md | Adds an [Unreleased] entry noting the batching synchronization fix; no issues. |
Sequence Diagram
sequenceDiagram
participant CA as Caller A
participant CB as Caller B
participant Lock as _batch_lock
participant RB as _run_batch task
participant Model as EmbeddingModel
CA->>Lock: acquire
CA->>Lock: "enqueue req_id=0, create batch events & task"
CA->>Lock: release (captures batch_finished_event)
CB->>Lock: acquire
CB->>Lock: "enqueue req_id=1, join same batch"
CB->>Lock: release (captures same batch_finished_event)
Note over CA,CB: Both await batch_finished_event.wait()
RB->>RB: asyncio.wait(sleep OR batch_full_event)
RB->>Lock: acquire (snapshot queue, clear state, set _batch_submitted)
RB->>Lock: release
RB->>Model: _get_embeddings([text0, text1])
Model-->>RB: [[e0], [e1]]
RB->>RB: "_req_results[0]=[e0], _req_results[1]=[e1]"
RB->>RB: batch_finished_event.set()
CA-->>CA: wake, pop _req_results[0], return [e0]
CB-->>CB: wake, pop _req_results[1], return [e1]
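The flow in the diagram can be sketched as a toy batcher. This is a minimal illustration, not the code from `basic.py`: the class `TinyBatcher`, its attribute names, the fixed hold time, and the fake `_embed` model are all assumptions made for the example. It only shows the two ideas the diagram relies on: callers capture their batch's event while holding the lock, and the batch runner snapshots and clears the queue under the same lock.

```python
import asyncio


class TinyBatcher:
    """Toy model of the flow in the diagram; not the real BasicEmbeddingsIndex."""

    def __init__(self, hold=0.01):
        self._lock = asyncio.Lock()
        self._queue = {}       # req_id -> text, waiting for the next batch
        self._results = {}     # req_id -> embedding
        self._next_id = 0
        self._finished = None  # event of the batch currently accepting requests
        self._task = None      # most recently created batch task
        self._hold = hold
        self.model_calls = 0

    async def _embed(self, texts):
        # Hypothetical model: one 1-d vector per text.
        self.model_calls += 1
        return [[float(len(t))] for t in texts]

    async def _run_batch(self, finished):
        try:
            await asyncio.sleep(self._hold)  # hold window so callers can join
            async with self._lock:
                # Snapshot and reset shared state under the lock.
                ids, texts = list(self._queue), list(self._queue.values())
                self._queue.clear()
                self._finished = None
            for req_id, emb in zip(ids, await self._embed(texts)):
                self._results[req_id] = emb
        finally:
            finished.set()  # wake every caller of *this* batch

    async def get(self, text):
        async with self._lock:
            req_id = self._next_id
            self._next_id += 1
            self._queue[req_id] = text
            if self._finished is None:
                self._finished = asyncio.Event()
                self._task = asyncio.create_task(self._run_batch(self._finished))
            finished = self._finished  # capture before releasing the lock
        await finished.wait()
        return self._results.pop(req_id)


async def demo():
    batcher = TinyBatcher()  # created inside the running loop
    results = await asyncio.gather(batcher.get("a"), batcher.get("bcd"))
    return results, batcher.model_calls


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Both callers in `demo` share one model call because the second one joins the batch before the hold window ends. The sketch deliberately omits the batch-full event, the capacity wait, and error propagation.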
Prompt To Fix All With AI
Fix the following 3 code review issues. Work through them one at a time, proposing concise fixes.
---
### Issue 1 of 3
nemoguardrails/embeddings/basic.py:201-209
**Leaked inner tasks when `_run_batch` is cancelled on Python < 3.12**
When `_run_batch`'s outer task is cancelled while suspended inside `asyncio.wait(...)`, Python 3.8–3.11 does not automatically cancel the two inner tasks. The cleanup loop `for task in pending: task.cancel()` is never reached, so `asyncio.create_task(batch_full_event.wait())` keeps running. Because `_current_batch_full_event` is set to `None` in the cancellation handler, that local `batch_full_event` object is only referenced by the leaked task and will never be set — making the `wait()` task effectively permanent until the event loop closes. Under sustained load with occasional `cancel()` calls this accumulates quietly. Python 3.12 added automatic cancellation of `asyncio.wait` sub-tasks on cancellation; if < 3.12 must be supported, the inner tasks should be cancelled explicitly in the `except asyncio.CancelledError` handler before `raise`.
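One way to make the cleanup independent of how the outer task exits is to cancel the inner tasks in a `finally` block. The sketch below is an assumption about how that could look, not the PR's code: `wait_for_hold_or_full` and the `created` list (used only so the demo can inspect the inner tasks) are hypothetical names.

```python
import asyncio


async def wait_for_hold_or_full(full_event, hold, created):
    """Wait for the hold timer or the 'batch full' event, whichever comes first."""
    sleep_task = asyncio.create_task(asyncio.sleep(hold))
    full_task = asyncio.create_task(full_event.wait())
    created.extend([sleep_task, full_task])  # demo-only bookkeeping
    try:
        await asyncio.wait(
            {sleep_task, full_task}, return_when=asyncio.FIRST_COMPLETED
        )
    finally:
        # Runs on normal exit and when the outer task is cancelled.
        for task in (sleep_task, full_task):
            if not task.done():
                task.cancel()


async def demo():
    inner = []
    outer = asyncio.create_task(wait_for_hold_or_full(asyncio.Event(), 60, inner))
    await asyncio.sleep(0)  # let the outer task start and suspend in asyncio.wait
    outer.cancel()
    # Wait for the outer task and both inner tasks to settle.
    await asyncio.gather(outer, *inner, return_exceptions=True)
    return [task.cancelled() for task in inner]
```

With the `finally` block in place, cancelling the outer task leaves no inner task pending, so the behaviour does not depend on what `asyncio.wait` does on a given Python version.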
### Issue 2 of 3
nemoguardrails/embeddings/basic.py:246-249
**Shared `CancelledError` instance mutates `__traceback__` across callers**
The same `exc` object is stored in `_req_errors` for every `req_id`. When several callers each execute `raise self._req_errors.pop(req_id)`, every `raise` mutates `exc.__traceback__` in place. The last raise wins, so all earlier call-site tracebacks are silently overwritten, making post-mortem debugging unreliable. Storing a distinct copy per request avoids this.
```suggestion
for req_id in batch_ids:
if req_id not in self._req_results and req_id not in self._req_errors:
self._req_errors[req_id] = asyncio.CancelledError(*exc.args)
raise
```
### Issue 3 of 3
nemoguardrails/embeddings/basic.py:93-94
**`_current_batch_task` only tracks the most-recently created batch**
Overlapping batches are possible: once `_run_batch` releases `_batch_lock` after draining the queue, new callers may create a second batch before the first `_get_embeddings` call returns. Each new batch overwrites `_current_batch_task`, so earlier concurrent batches are no longer reachable via this attribute. Any consumer relying on `_current_batch_task.cancel()` to halt all in-flight work (e.g. during shutdown) will silently miss earlier batches. The comment "Stored so callers can cancel… active batch task" overpromises — consider renaming to `_last_batch_task` or documenting the single-task limitation explicitly.
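If shutdown needs to reach every in-flight batch, a common asyncio pattern is to keep a set of tasks and let each task remove itself when done. The following is a sketch under that assumption; `BatchTaskRegistry` and its methods are hypothetical and do not exist in the PR.

```python
import asyncio


class BatchTaskRegistry:
    """Hypothetical helper that tracks every in-flight batch task."""

    def __init__(self):
        self._tasks = set()

    def __len__(self):
        return len(self._tasks)

    def spawn(self, coro):
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        # Drop the reference as soon as the task finishes or is cancelled.
        task.add_done_callback(self._tasks.discard)
        return task

    async def cancel_all(self):
        tasks = list(self._tasks)
        for task in tasks:
            task.cancel()
        # Wait until every task has actually processed its cancellation.
        await asyncio.gather(*tasks, return_exceptions=True)
        return len(tasks)


async def demo():
    registry = BatchTaskRegistry()
    first = registry.spawn(asyncio.sleep(60))
    second = registry.spawn(asyncio.sleep(60))
    await asyncio.sleep(0)  # let both tasks start
    cancelled = await registry.cancel_all()
    return cancelled, first.cancelled(), second.cancelled(), len(registry)
```

Compared with a single `_current_batch_task` attribute, this keeps overlapping batches reachable at the cost of a little extra state.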
Reviews (7): Last reviewed commit: "fix: raise RuntimeError instead of KeyEr..." | Re-trigger Greptile
except BaseException as exc:
    for req_id in batch_ids:
        if req_id not in self._req_results and req_id not in self._req_errors:
            self._req_errors[req_id] = exc
finally:
    batch_finished_event.set()
BaseException catch swallows CancelledError
Catching BaseException includes asyncio.CancelledError (a BaseException subclass in Python 3.8+). If _run_batch's underlying task is externally cancelled while awaiting _get_embeddings, the cancellation is absorbed rather than propagated, so self._current_batch_task.cancel() will appear to succeed but the task will actually run to completion. The fix is to catch only Exception (plus a separate handler for CancelledError that re-raises after storing the error), or at minimum re-raise CancelledError after recording it in _req_errors, relying on the finally block to still set batch_finished_event.
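The split handler described above can be sketched as a standalone coroutine. The function below is an illustration only: `run_batch`, its parameters, and the `_record` helper are assumed names, and plain dicts stand in for `_req_results` and `_req_errors`.

```python
import asyncio


def _record(errors, results, batch_ids, exc):
    # Store the error only for requests that have no outcome yet.
    for req_id in batch_ids:
        if req_id not in results and req_id not in errors:
            errors[req_id] = exc


async def run_batch(embed, batch_ids, texts, results, errors, finished):
    try:
        embeddings = await embed(texts)
        for req_id, emb in zip(batch_ids, embeddings):
            results[req_id] = emb
    except asyncio.CancelledError as exc:
        _record(errors, results, batch_ids, exc)
        raise  # re-raise so the task really ends up cancelled
    except Exception as exc:
        _record(errors, results, batch_ids, exc)  # model failure: do not re-raise
    finally:
        finished.set()  # callers are woken in every case


async def demo_cancel():
    started = asyncio.Event()

    async def slow_model(texts):
        started.set()
        await asyncio.sleep(60)

    results, errors, finished = {}, {}, asyncio.Event()
    task = asyncio.create_task(
        run_batch(slow_model, [0, 1], ["a", "b"], results, errors, finished)
    )
    await started.wait()
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    return task.cancelled(), finished.is_set(), sorted(errors)
```

Listing the `CancelledError` handler before the `Exception` handler keeps the behaviour the same on interpreters where `CancelledError` still derives from `Exception`.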
Codecov Report: ❌ Patch coverage is
@coderabbitai review
@greptile-apps review
✅ Actions performed: Review triggered.
🧹 Nitpick comments (1)
nemoguardrails/embeddings/basic.py (1)
257-262: 💤 Low value. Remove or actually use `self._current_batch_task`.
`nemoguardrails/embeddings/basic.py` assigns `self._current_batch_task = asyncio.create_task(...)`, but `_current_batch_task` is never awaited/cancelled/referenced anywhere else (only declared and assigned). If it’s not meant to be managed externally, remove the field and create the task without storing it; otherwise document the intended lifecycle/ownership.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@nemoguardrails/embeddings/basic.py` around lines 257 - 262, The assignment to self._current_batch_task creates an asyncio.Task for self._run_batch but the attribute is never awaited, cancelled, or referenced; either remove self._current_batch_task and call asyncio.create_task(...) without storing the result, or explicitly manage the task lifecycle: store it as self._current_batch_task, document ownership, and ensure you await or cancel it (e.g., in shutdown/cleanup) and check its status where appropriate; the change should be applied around the creation call that passes _current_batch_full_event and _current_batch_finished_event to _run_batch so the code no longer leaks an unmanaged task or leaves an unused attribute.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 30f58987-fdef-4e34-aa0b-77cc4476dec8
📒 Files selected for processing (2)
- nemoguardrails/embeddings/basic.py
- tests/test_batch_embeddings.py
Signed-off-by: Pradyoth P <pradyoth0@gmail.com>
- Wrap _get_embeddings call in try/except/finally so batch_finished_event is always set even on exception; store per-req errors in _req_errors so callers observe the failure instead of hanging indefinitely
- Replace asyncio.ensure_future with asyncio.get_event_loop().create_task (ensure_future deprecated since Python 3.10)
- Add call_count to MockEmbeddingModel and assert at least 3 batches are dispatched for 5 requests with max_batch_size=2
- Apply ruff formatting across both changed files
Adds test_batch_get_embeddings_propagates_model_error to exercise the except/finally branches added to _run_batch and the re-raise in _batch_get_embeddings, bringing codecov patch coverage to 100%.
Replaces asyncio.get_event_loop().create_task() with asyncio.create_task() (get_event_loop() deprecated in Python 3.10+) and stores the returned Task in self._current_batch_task so it is not silently discarded.
- Rename unused done from asyncio.wait to _
- After _get_embeddings returns, check that the number of embeddings matches batch_ids; populate _req_errors for any missing slots so callers always receive either a result or an exception
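The length check in this commit can be illustrated with a small helper. This is a sketch, not the code from `basic.py`: `store_results` is a hypothetical function, plain dicts stand in for `_req_results` and `_req_errors`, and the exact wording of the error message is assumed.

```python
def store_results(batch_ids, embeddings, results, errors):
    """Record one outcome per request, even if the model returns too few items."""
    # Store whatever embeddings were returned, in request order.
    for req_id, emb in zip(batch_ids, embeddings):
        results[req_id] = emb
    # Requests beyond the returned length get an explicit error instead of nothing.
    for req_id in batch_ids[len(embeddings):]:
        errors[req_id] = RuntimeError(
            f"Embedding model returned {len(embeddings)} embeddings "
            f"for {len(batch_ids)} requests"
        )


results, errors = {}, {}
store_results([7, 8, 9], [[0.1], [0.2]], results, errors)
```

After this call, requests 7 and 8 have results and request 9 has an error, so no caller wakes up to find neither.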
- Split BaseException handler into CancelledError (stores error, re-raises so task cancellation propagates correctly) and Exception (stores error without re-raising); batch_finished_event is still set via finally in both cases
- Wrap the wait-and-retrieve block in _batch_get_embeddings with try/finally so _req_results and _req_errors entries are always cleaned up even when the caller is cancelled mid-wait
- Fix test_batch_get_embeddings_propagates_short_result: match pattern 'fewer embeddings' did not appear in the actual error message; corrected to 'Embedding model returned' (CI was failing on all Python versions)
- Move try/finally to wrap the full _run_batch body (including the initial asyncio.wait and lock section) so batch_finished_event is guaranteed to be set regardless of where a cancellation or error fires; batch_ids initialised to [] before the try so except handlers never hit NameError
- Add comment to _current_batch_task explaining it is stored for lifecycle control (e.g. cancellation, shutdown)
- Add test_batch_get_embeddings_propagates_cancelled_batch_task: signals when encode_async starts, cancels the batch task at that point, and asserts CancelledError propagates to the caller; covers the previously uncovered except asyncio.CancelledError block
10dd63d to 8b59599
@coderabbitai review
@greptile-apps review
✅ Actions performed: Review triggered.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@nemoguardrails/embeddings/basic.py`:
- Around line 249-280: The code can deadlock when max_batch_size <= 0 because
the enqueue loop (uses self._batch_lock, self._req_queue,
self._current_batch_task, and calls self._run_batch) never creates a batch; add
a guard that validates max_batch_size is a positive integer and raise ValueError
early (either in the class __init__ or at the start of the enqueue method that
contains this loop) so callers fail fast; ensure the check references
self.max_batch_size and prevents entering the while True block when invalid.
- Around line 196-246: The cancellation can occur before batch_ids is populated,
leaving batch_finished_event set and the queue still pointing to the cancelled
batch; to fix, when catching asyncio.CancelledError (or immediately after
detecting cancellation) acquire self._batch_lock and snapshot the queued ids for
this batch (e.g., read keys from self._req_queue into batch_ids and clear those
entries) and then mark each of those ids in self._req_errors with the
CancelledError before re-raising; ensure you also clear/restore
_current_batch_finished_event/_current_batch_full_event and set
_current_batch_submitted appropriately so the internal batch state no longer
references the cancelled batch (use symbols: self._batch_lock, batch_ids,
self._req_queue, self._req_errors, batch_finished_event,
self._current_batch_finished_event, self._current_batch_full_event,
self._current_batch_submitted).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: af0d035e-acfd-4ed3-8c36-9a843ee230da
📒 Files selected for processing (3)
- CHANGELOG.md
- nemoguardrails/embeddings/basic.py
- tests/test_batch_embeddings.py
while True:
    async with self._batch_lock:
        if len(self._req_queue) < self.max_batch_size:
            req_id = self._req_idx
            self._req_idx += 1
            self._req_queue[req_id] = text

            if self._current_batch_finished_event is None or self._current_batch_full_event is None:
                self._current_batch_finished_event = asyncio.Event()
                self._current_batch_full_event = asyncio.Event()
                self._current_batch_submitted.clear()
                self._current_batch_task = asyncio.create_task(
                    self._run_batch(
                        self._current_batch_full_event,
                        self._current_batch_finished_event,
                    )
                )

            batch_finished_event = self._current_batch_finished_event
            batch_full_event = self._current_batch_full_event
            if batch_finished_event is None or batch_full_event is None:
                raise RuntimeError("Batch events not initialized. This should not happen.")

            # We check if we reached the max batch size
            if len(self._req_queue) >= self.max_batch_size:
                batch_full_event.set()

            break

        batch_submitted_event = self._current_batch_submitted

    await batch_submitted_event.wait()
Reject non-positive max_batch_size.
With max_batch_size <= 0, Line 251 is never true, no batch task is ever created, and every caller blocks forever on Line 280. This should fail fast with a ValueError during initialization or before entering the loop.
Minimal guard
  self.use_batching = use_batching
+ if max_batch_size <= 0:
+     raise ValueError("max_batch_size must be greater than 0")
  self.max_batch_size = max_batch_size
  self.max_batch_hold = max_batch_hold
except asyncio.CancelledError as exc:
    for req_id in batch_ids:
        if req_id not in self._req_results and req_id not in self._req_errors:
            self._req_errors[req_id] = exc
    raise
except Exception as exc:
    for req_id in batch_ids:
        if req_id not in self._req_results and req_id not in self._req_errors:
            self._req_errors[req_id] = exc
finally:
    batch_finished_event.set()
Early cancellation leaves callers with KeyError and blocks full-queue waiters
batch_ids is initialised as [] on line 198 and is only populated inside the async with self._batch_lock section (line 218). If the outer task is cancelled while awaiting asyncio.wait([sleep, batch_full_event.wait()]) — e.g. via _current_batch_task.cancel() during shutdown before the hold timer expires — CancelledError is thrown before the lock section ever runs. The except asyncio.CancelledError handler then iterates an empty batch_ids, so no error is stored for any of the requests that are still in _req_queue. finally sets batch_finished_event, those callers wake up, find nothing in _req_results or _req_errors, and crash with KeyError from self._req_results.pop(req_id).
A second failure follows: _current_batch_submitted.set() is only called inside the lock section. Any caller blocked on await batch_submitted_event.wait() because the queue was at capacity will deadlock — nothing will ever set that event again, and the new-batch creation path is also stuck because _current_batch_finished_event and _current_batch_full_event are left pointing to the already-signalled, stale events from the cancelled batch.
The existing test test_batch_get_embeddings_propagates_cancelled_batch_task avoids this path by waiting for encoding_started before cancelling, guaranteeing the lock section has already run. A minimal fix is to handle remaining queue entries in the except handlers (safe without the lock because asyncio is single-threaded between await points) and unconditionally call self._current_batch_submitted.set() in the finally block.
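The minimal fix described above can be sketched in isolation. Everything here is an assumption for illustration: `run_batch` takes plain dicts and events as parameters instead of using instance attributes, the lock is left out, and a fixed `hold` sleep stands in for the real `asyncio.wait` on the timer and the batch-full event.

```python
import asyncio


async def run_batch(queue, results, errors, submitted, finished, embed, hold):
    batch_ids = []
    try:
        await asyncio.sleep(hold)  # an early cancellation lands here
        batch_ids = list(queue)
        texts = [queue.pop(req_id) for req_id in batch_ids]
        for req_id, emb in zip(batch_ids, await embed(texts)):
            results[req_id] = emb
    except asyncio.CancelledError as exc:
        if not batch_ids:
            # Cancelled before the snapshot: claim the queued requests now.
            # There is no await between here and `raise`, so nothing interleaves.
            batch_ids = list(queue)
            queue.clear()
        for req_id in batch_ids:
            if req_id not in results and req_id not in errors:
                errors[req_id] = exc
        raise
    finally:
        submitted.set()  # release callers waiting for queue capacity
        finished.set()   # release callers waiting for results


async def demo():
    async def never_called(texts):
        return []

    queue, results, errors = {0: "a", 1: "b"}, {}, {}
    submitted, finished = asyncio.Event(), asyncio.Event()
    task = asyncio.create_task(
        run_batch(queue, results, errors, submitted, finished, never_called, 60)
    )
    await asyncio.sleep(0)  # let the task reach its first await
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    return sorted(errors), queue, submitted.is_set(), finished.is_set()
```

In the demo the task is cancelled during the hold window, and both queued requests still end up with a `CancelledError` recorded and both events set.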
KeyError (not CancelledError) when task is cancelled before lock acquisition
batch_ids is initialised to [] before the try block. If _current_batch_task.cancel() fires while _run_batch is suspended inside asyncio.wait(...), or while it is waiting to acquire _batch_lock (another _batch_get_embeddings call holds it briefly), the except asyncio.CancelledError handler iterates over the empty list and writes nothing to _req_errors. batch_finished_event.set() still fires in finally, waking every caller. Those callers then find no entry in _req_errors and fall through to self._req_results.pop(req_id), which raises KeyError.
A secondary consequence of the same scenario: the lock section never runs, so _current_batch_submitted is never set and _req_queue is never cleared. Any overflow caller blocked on await batch_submitted_event.wait() is permanently stuck, and future callers joining the still-non-None _current_batch_finished_event would also hit the same KeyError.
The fix requires either (a) snapshotting _req_queue keys before entering asyncio.wait so they are available to the early-cancel path, or (b) guarding self._req_results.pop(req_id) in _batch_get_embeddings to raise a typed error when the key is absent after the event fires.
…lated
If _run_batch is cancelled during the asyncio.wait (before acquiring the lock), batch_ids remains empty. Drain _req_queue without the lock, which is safe because asyncio is single-threaded and no await occurs before `raise`. Also unconditionally call _current_batch_submitted.set() in finally so queue-full waiters are never left blocked. Add test_batch_get_embeddings_propagates_early_cancellation to cover this path explicitly.
Guard `_req_results.pop(req_id)` in `_batch_get_embeddings` so that if a req_id is absent from both `_req_results` and `_req_errors` when the batch event fires, callers get a clear RuntimeError rather than a confusing KeyError. Paired with the early-cancellation fix that ensures all in-flight request IDs are written to `_req_errors` before the event is set.
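The guard described in this commit can be sketched as a small lookup helper. `take_result` is a hypothetical name, plain dicts stand in for `_req_results` and `_req_errors`, and the error message wording is assumed.

```python
def take_result(req_id, results, errors):
    """Return the stored embedding, re-raise a stored error, or fail clearly."""
    if req_id in errors:
        raise errors.pop(req_id)
    if req_id not in results:
        # Neither outcome was recorded: report it explicitly instead of KeyError.
        raise RuntimeError(f"No result or error recorded for request {req_id}")
    return results.pop(req_id)


results = {1: [0.5]}
errors = {2: ValueError("model failed")}
embedding = take_result(1, results, errors)
```

A caller that finds neither entry gets a `RuntimeError` naming the request, which is easier to diagnose than a bare `KeyError` from `dict.pop`.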
@coderabbitai review
@greptile-apps review
✅ Actions performed: Review triggered.
🧹 Nitpick comments (2)
tests/test_batch_embeddings.py (2)
202-202: 💤 Low value. Consider using `asyncio.create_task` for consistency.
The PR replaced `asyncio.ensure_future` with `asyncio.create_task` in the implementation. For consistency, this manual test could use `create_task` as well, though the impact is minimal since the test is skipped.
✨ Consistency fix
- task = asyncio.ensure_future(_search(f"This is a long sentence meant to mimic a user request {i}." * 5))
+ task = asyncio.create_task(_search(f"This is a long sentence meant to mimic a user request {i}." * 5))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_batch_embeddings.py` at line 202, Replace the use of asyncio.ensure_future in the manual test with asyncio.create_task for consistency with the implementation; specifically, change the line creating task for calling _search (currently using asyncio.ensure_future(...)) to use asyncio.create_task(...) and keep the same argument string construction and assignment to task to preserve behavior.
105-106: ⚖️ Poor tradeoff. Clarify the need for double `asyncio.sleep(0)` or use explicit synchronization.
Two consecutive `asyncio.sleep(0)` calls seem redundant and make the test's timing assumptions unclear. The earlier test `test_batch_get_embeddings_propagates_cancelled_batch_task` (lines 68-89) uses an explicit `encoding_started` event to reliably synchronize cancellation timing.
Consider either:
- Removing the duplicate sleep if one suffices, or
- Adding a comment explaining why two yields are necessary, or
- Introducing an explicit synchronization point (mock hook or event) to make the test more deterministic and less fragile across different execution environments.
💡 Example: explicit synchronization
batch_started = asyncio.Event()

class StartSignalingModel:
    async def encode_async(self, texts):
        batch_started.set()
        # ... rest of logic

# In test:
# await batch_started.wait()  # instead of sleep(0); sleep(0)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_batch_embeddings.py` around lines 105 - 106, Replace the unclear double-yield (two consecutive asyncio.sleep(0) calls) with deterministic synchronization: either remove the duplicate sleep if one yield is sufficient, or, preferably, add an asyncio.Event (e.g., batch_started) and modify the mock/model method (e.g., StartSignalingModel.encode_async or the mocked encode_async used by the test) to call batch_started.set() when a batch begins, then in the test await batch_started.wait() instead of the two sleeps; ensure the test function (the one containing the two asyncio.sleep(0) calls) uses the event to synchronize cancellation timing reliably.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: c01d3bc5-f8e1-4035-bad5-d76fbe040ccb
📒 Files selected for processing (2)
- nemoguardrails/embeddings/basic.py
- tests/test_batch_embeddings.py
Summary
Fixes #1476.
This updates batched embedding requests so each caller joins a batch while holding a small async lock, captures the finished/full events for that specific batch, and has the batch runner snapshot and clear the queued requests under the same lock. That avoids callers awaiting a later batch event or seeing a partially initialized event pair under concurrent load.
Key changes:
- `_run_batch` now wraps `_get_embeddings` in `try/except/finally` so `batch_finished_event` is always set; embedding model failures are stored per-request in `_req_errors` and re-raised in `_batch_get_embeddings` instead of hanging callers forever
- Replaces `asyncio.ensure_future` with `asyncio.create_task` (deprecated since Python 3.10); task reference stored in `self._current_batch_task`
- […] `call_count >= 3`)
Testing
- `python3 -m black nemoguardrails/embeddings/basic.py tests/test_batch_embeddings.py`
- `.venv/bin/python -m pytest tests/test_batch_embeddings.py -q`
- `ruff format --check nemoguardrails/embeddings/basic.py tests/test_batch_embeddings.py`
- `git diff --check`
Summary by CodeRabbit