-
Notifications
You must be signed in to change notification settings - Fork 704
fix: synchronize embedding batches #1921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
81e99b1
d1fbf6e
e358c9d
b13239f
3792fb9
35650d4
f90e588
8b59599
edb6464
56e002f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,10 +86,14 @@ def __init__( | |
| # Data structures for batching embedding requests | ||
| self._req_queue: Dict[int, str] = {} | ||
| self._req_results: Dict[int, List[float]] = {} | ||
| self._req_errors: Dict[int, BaseException] = {} | ||
| self._req_idx: int = 0 | ||
| self._current_batch_finished_event: Optional[asyncio.Event] = None | ||
| self._current_batch_full_event: Optional[asyncio.Event] = None | ||
| # Stored so callers can cancel or inspect the active batch task (e.g. shutdown). | ||
| self._current_batch_task: Optional[asyncio.Task] = None | ||
| self._current_batch_submitted: asyncio.Event = asyncio.Event() | ||
| self._batch_lock: asyncio.Lock = asyncio.Lock() | ||
|
|
||
| # Initialize the batching configuration | ||
| self.use_batching = use_batching | ||
|
|
@@ -187,74 +191,119 @@ async def build(self): | |
| self._index.add_item(i, self._embeddings[i]) | ||
| self._index.build(10) | ||
|
|
||
| async def _run_batch(self): | ||
| async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event): | ||
| """Runs the current batch of embeddings.""" | ||
|
|
||
| # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. | ||
| if self._current_batch_full_event is None or self._current_batch_finished_event is None: | ||
| raise RuntimeError("Batch events not initialized. This should not happen.") | ||
|
|
||
| done, pending = await asyncio.wait( | ||
| [ | ||
| asyncio.create_task(asyncio.sleep(self.max_batch_hold)), | ||
| asyncio.create_task(self._current_batch_full_event.wait()), | ||
| ], | ||
| return_when=asyncio.FIRST_COMPLETED, | ||
| ) | ||
| for task in pending: | ||
| task.cancel() | ||
|
|
||
| # Reset the batch event | ||
| batch_event: asyncio.Event = self._current_batch_finished_event | ||
| self._current_batch_finished_event = None | ||
|
|
||
| # Create the actual batch to be send for computing | ||
| batch = [] | ||
| batch_ids = list(self._req_queue.keys()) | ||
| for req_id in batch_ids: | ||
| batch.append(self._req_queue[req_id]) | ||
|
|
||
| # Empty the queue up to this point | ||
| self._req_queue = {} | ||
|
|
||
| # We allow other batches to start | ||
| self._current_batch_submitted.set() | ||
|
|
||
| # print(f"Running batch of length {len(batch)}") | ||
|
|
||
| # Compute the embeddings | ||
| embeddings = await self._get_embeddings(batch) | ||
| for i in range(len(embeddings)): | ||
| self._req_results[batch_ids[i]] = embeddings[i] | ||
|
|
||
| # Signal that the batch has finished processing | ||
| batch_event.set() | ||
| # Initialised here so the except handlers can safely iterate even if | ||
| # cancellation fires before the lock section populates batch_ids. | ||
| batch_ids: List[int] = [] | ||
| try: | ||
| # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. | ||
| _, pending = await asyncio.wait( | ||
| [ | ||
| asyncio.create_task(asyncio.sleep(self.max_batch_hold)), | ||
| asyncio.create_task(batch_full_event.wait()), | ||
| ], | ||
| return_when=asyncio.FIRST_COMPLETED, | ||
| ) | ||
| for task in pending: | ||
| task.cancel() | ||
|
|
||
| async with self._batch_lock: | ||
| # Reset the active batch only if it has not already rolled over. | ||
| if self._current_batch_finished_event is batch_finished_event: | ||
| self._current_batch_finished_event = None | ||
| self._current_batch_full_event = None | ||
|
|
||
| # Create the actual batch to be sent for computing. | ||
| batch_ids = list(self._req_queue.keys()) | ||
| batch = [self._req_queue[req_id] for req_id in batch_ids] | ||
|
|
||
| # Empty the queue up to this point. | ||
| self._req_queue = {} | ||
|
|
||
| # We allow other batches to start. | ||
| self._current_batch_submitted.set() | ||
|
|
||
| embeddings = await self._get_embeddings(batch) | ||
| if len(embeddings) < len(batch_ids): | ||
| shortage_exc = RuntimeError( | ||
| f"Embedding model returned {len(embeddings)} embeddings for {len(batch_ids)} inputs." | ||
| ) | ||
| for req_id in batch_ids[len(embeddings) :]: | ||
| self._req_errors[req_id] = shortage_exc | ||
| for i in range(len(embeddings)): | ||
| self._req_results[batch_ids[i]] = embeddings[i] | ||
| except asyncio.CancelledError as exc: | ||
| # If cancelled before the lock section, batch_ids is still []. | ||
| # Snapshot and drain the queue without the lock — safe because asyncio | ||
| # is single-threaded and no await occurs between here and `raise`. | ||
| if not batch_ids: | ||
| batch_ids = list(self._req_queue.keys()) | ||
| self._req_queue = {} | ||
| if self._current_batch_finished_event is batch_finished_event: | ||
| self._current_batch_finished_event = None | ||
| self._current_batch_full_event = None | ||
| for req_id in batch_ids: | ||
| if req_id not in self._req_results and req_id not in self._req_errors: | ||
| self._req_errors[req_id] = exc | ||
| raise | ||
| except Exception as exc: | ||
| for req_id in batch_ids: | ||
| if req_id not in self._req_results and req_id not in self._req_errors: | ||
| self._req_errors[req_id] = exc | ||
| finally: | ||
| # Unconditionally unblock full-queue waiters in case the lock section | ||
| # was never reached (early cancellation). | ||
| self._current_batch_submitted.set() | ||
| batch_finished_event.set() | ||
|
coderabbitai[bot] marked this conversation as resolved.
Comment on lines
+236
to
+258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
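A second failure follows: `_current_batch_submitted.set()` is only called inside the lock section, so callers blocked on a full queue are never released. The existing test `test_batch_get_embeddings_propagates_cancelled_batch_task` avoids this path by waiting for `encoding_started` before cancelling.
Prompt To Fix With AI
This is a comment left during a code review.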
Path: nemoguardrails/embeddings/basic.py
Line: 236-246
Comment:
**Early cancellation leaves callers with `KeyError` and blocks full-queue waiters**
`batch_ids` is initialised as `[]` on line 198 and is only populated inside the `async with self._batch_lock` section (line 218). If the outer task is cancelled while awaiting `asyncio.wait([sleep, batch_full_event.wait()])` — e.g. via `_current_batch_task.cancel()` during shutdown before the hold timer expires — `CancelledError` is thrown before the lock section ever runs. The `except asyncio.CancelledError` handler then iterates an empty `batch_ids`, so no error is stored for any of the requests that are still in `_req_queue`. `finally` sets `batch_finished_event`, those callers wake up, find nothing in `_req_results` or `_req_errors`, and crash with `KeyError` from `self._req_results.pop(req_id)`.
A second failure follows: `_current_batch_submitted.set()` is only called inside the lock section. Any caller blocked on `await batch_submitted_event.wait()` because the queue was at capacity will deadlock — nothing will ever set that event again, and the new-batch creation path is also stuck because `_current_batch_finished_event` and `_current_batch_full_event` are left pointing to the already-signalled, stale events from the cancelled batch.
The existing test `test_batch_get_embeddings_propagates_cancelled_batch_task` avoids this path by waiting for `encoding_started` before cancelling, guaranteeing the lock section has already run. A minimal fix is to handle remaining queue entries in the except handlers (safe without the lock because asyncio is single-threaded between `await` points) and unconditionally call `self._current_batch_submitted.set()` in the `finally` block.
How can I resolve this? If you propose a fix, please make it concise.
Comment on lines
+236
to
+258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
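A secondary consequence of the same scenario: the lock section never runs, so `_current_batch_submitted` is never set and `_req_queue` is never cleared. The fix requires either (a) snapshotting `_req_queue` keys before entering `asyncio.wait`, or (b) guarding `self._req_results.pop(req_id)` in `_batch_get_embeddings`.
Prompt To Fix With AI
This is a comment left during a code review.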
Path: nemoguardrails/embeddings/basic.py
Line: 236-246
Comment:
**`KeyError` (not `CancelledError`) when task is cancelled before lock acquisition**
`batch_ids` is initialised to `[]` before the `try` block. If `_current_batch_task.cancel()` fires while `_run_batch` is suspended inside `asyncio.wait(...)`, or while it is waiting to acquire `_batch_lock` (another `_batch_get_embeddings` call holds it briefly), the `except asyncio.CancelledError` handler iterates over the empty list and writes nothing to `_req_errors`. `batch_finished_event.set()` still fires in `finally`, waking every caller. Those callers then find no entry in `_req_errors` and fall through to `self._req_results.pop(req_id)`, which raises `KeyError`.
A secondary consequence of the same scenario: the lock section never runs, so `_current_batch_submitted` is never set and `_req_queue` is never cleared. Any overflow caller blocked on `await batch_submitted_event.wait()` is permanently stuck, and future callers joining the still-non-None `_current_batch_finished_event` would also hit the same `KeyError`.
The fix requires either (a) snapshotting `_req_queue` keys before entering `asyncio.wait` so they are available to the early-cancel path, or (b) guarding `self._req_results.pop(req_id)` in `_batch_get_embeddings` to raise a typed error when the key is absent after the event fires.
How can I resolve this? If you propose a fix, please make it concise. |
||
|
|
||
| async def _batch_get_embeddings(self, text: str) -> List[float]: | ||
| # As long as the queue is full, we wait for the next batch | ||
| while len(self._req_queue) >= self.max_batch_size: | ||
| await self._current_batch_submitted.wait() | ||
|
|
||
| req_id = self._req_idx | ||
| self._req_idx += 1 | ||
| self._req_queue[req_id] = text | ||
|
|
||
| if self._current_batch_finished_event is None or self._current_batch_full_event is None: | ||
| self._current_batch_finished_event = asyncio.Event() | ||
| self._current_batch_full_event = asyncio.Event() | ||
| self._current_batch_submitted.clear() | ||
| asyncio.ensure_future(self._run_batch()) | ||
|
|
||
| # We check if we reached the max batch size | ||
| if len(self._req_queue) >= self.max_batch_size: | ||
| self._current_batch_full_event.set() | ||
|
|
||
| # Wait for the batch to finish | ||
| await self._current_batch_finished_event.wait() | ||
|
|
||
| # Remove the result and return it | ||
| result = self._req_results[req_id] | ||
| del self._req_results[req_id] | ||
| while True: | ||
| async with self._batch_lock: | ||
| if len(self._req_queue) < self.max_batch_size: | ||
| req_id = self._req_idx | ||
| self._req_idx += 1 | ||
| self._req_queue[req_id] = text | ||
|
|
||
| if self._current_batch_finished_event is None or self._current_batch_full_event is None: | ||
| self._current_batch_finished_event = asyncio.Event() | ||
| self._current_batch_full_event = asyncio.Event() | ||
| self._current_batch_submitted.clear() | ||
| self._current_batch_task = asyncio.create_task( | ||
| self._run_batch( | ||
| self._current_batch_full_event, | ||
| self._current_batch_finished_event, | ||
| ) | ||
| ) | ||
|
|
||
| batch_finished_event = self._current_batch_finished_event | ||
| batch_full_event = self._current_batch_full_event | ||
| if batch_finished_event is None or batch_full_event is None: | ||
| raise RuntimeError("Batch events not initialized. This should not happen.") | ||
|
|
||
| # We check if we reached the max batch size | ||
| if len(self._req_queue) >= self.max_batch_size: | ||
| batch_full_event.set() | ||
|
|
||
| break | ||
|
|
||
| batch_submitted_event = self._current_batch_submitted | ||
|
|
||
| await batch_submitted_event.wait() | ||
|
Comment on lines
+261
to
+292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reject non-positive `max_batch_size`. With `max_batch_size <= 0`, the `len(self._req_queue) < self.max_batch_size` check can never pass, so the caller never enqueues its request. Minimal guard:
  self.use_batching = use_batching
+ if max_batch_size <= 0:
+     raise ValueError("max_batch_size must be greater than 0")
  self.max_batch_size = max_batch_size
  self.max_batch_hold = max_batch_hold
🤖 Prompt for AI Agents |
||
|
|
||
| # Wait for the batch to finish; clean up our slot regardless of how we exit. | ||
| try: | ||
| await batch_finished_event.wait() | ||
|
|
||
| if req_id in self._req_errors: | ||
| raise self._req_errors.pop(req_id) | ||
|
|
||
| if req_id not in self._req_results: | ||
| raise RuntimeError(f"Batch completed without a result for request {req_id}.") | ||
| result = self._req_results.pop(req_id) | ||
| finally: | ||
| self._req_results.pop(req_id, None) | ||
| self._req_errors.pop(req_id, None) | ||
|
|
||
| return result | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.