From 81e99b187d653abcd1f10ba041cc0779a1428d2d Mon Sep 17 00:00:00 2001 From: Pradyoth P Date: Fri, 22 May 2026 22:01:28 +0530 Subject: [PATCH 01/10] fix: synchronize embedding batches Signed-off-by: Pradyoth P --- CHANGELOG.md | 7 +- nemoguardrails/embeddings/basic.py | 114 ++++++++++++++++++----------- tests/test_batch_embeddings.py | 37 +++++++++- 3 files changed, 113 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28d4e9454a..55baaa5fbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm > > The changes related to the Colang language and runtime have moved to [CHANGELOG-Colang](./CHANGELOG-Colang.md) file. +## [Unreleased] + +### 🐛 Bug Fixes + +- *(embeddings)* Synchronize batched embedding requests to avoid deadlocks ([#1476](https://github.com/NVIDIA-NeMo/Guardrails/issues/1476)) + ## [0.22.0] - 2026-05-22 ### 🚀 Features @@ -111,7 +117,6 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - Restore original 2023-2026 copyright dates on moved files ([#1831](https://github.com/NVIDIA-NeMo/Guardrails/issues/1831)) - Include scripts in docker image ([#1902](https://github.com/NVIDIA-NeMo/Guardrails/issues/1902)) - ## [0.21.0] - 2026-03-12 ### 🚀 Features diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 65ccf9ed11..9babca7ca0 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -90,6 +90,7 @@ def __init__( self._current_batch_finished_event: Optional[asyncio.Event] = None self._current_batch_full_event: Optional[asyncio.Event] = None self._current_batch_submitted: asyncio.Event = asyncio.Event() + self._batch_lock: asyncio.Lock = asyncio.Lock() # Initialize the batching configuration self.use_batching = use_batching @@ -175,7 +176,9 @@ async def add_items(self, items: List[IndexItem]): # If the index is already built, we skip this if 
self._index is None: - self._embeddings.extend(await self._get_embeddings([item.text for item in items])) + self._embeddings.extend( + await self._get_embeddings([item.text for item in items]) + ) # Update the embedding if it was not computed up to this point self._embedding_size = len(self._embeddings[0]) @@ -187,38 +190,37 @@ async def build(self): self._index.add_item(i, self._embeddings[i]) self._index.build(10) - async def _run_batch(self): + async def _run_batch( + self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event + ): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. - if self._current_batch_full_event is None or self._current_batch_finished_event is None: - raise RuntimeError("Batch events not initialized. This should not happen.") - done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), - asyncio.create_task(self._current_batch_full_event.wait()), + asyncio.create_task(batch_full_event.wait()), ], return_when=asyncio.FIRST_COMPLETED, ) for task in pending: task.cancel() - # Reset the batch event - batch_event: asyncio.Event = self._current_batch_finished_event - self._current_batch_finished_event = None + async with self._batch_lock: + # Reset the active batch only if it has not already rolled over. + if self._current_batch_finished_event is batch_finished_event: + self._current_batch_finished_event = None + self._current_batch_full_event = None - # Create the actual batch to be send for computing - batch = [] - batch_ids = list(self._req_queue.keys()) - for req_id in batch_ids: - batch.append(self._req_queue[req_id]) + # Create the actual batch to be sent for computing. + batch_ids = list(self._req_queue.keys()) + batch = [self._req_queue[req_id] for req_id in batch_ids] - # Empty the queue up to this point - self._req_queue = {} + # Empty the queue up to this point. 
+ self._req_queue = {} - # We allow other batches to start - self._current_batch_submitted.set() + # We allow other batches to start. + self._current_batch_submitted.set() # print(f"Running batch of length {len(batch)}") @@ -228,29 +230,49 @@ async def _run_batch(self): self._req_results[batch_ids[i]] = embeddings[i] # Signal that the batch has finished processing - batch_event.set() + batch_finished_event.set() async def _batch_get_embeddings(self, text: str) -> List[float]: - # As long as the queue is full, we wait for the next batch - while len(self._req_queue) >= self.max_batch_size: - await self._current_batch_submitted.wait() - - req_id = self._req_idx - self._req_idx += 1 - self._req_queue[req_id] = text - - if self._current_batch_finished_event is None or self._current_batch_full_event is None: - self._current_batch_finished_event = asyncio.Event() - self._current_batch_full_event = asyncio.Event() - self._current_batch_submitted.clear() - asyncio.ensure_future(self._run_batch()) - - # We check if we reached the max batch size - if len(self._req_queue) >= self.max_batch_size: - self._current_batch_full_event.set() + while True: + async with self._batch_lock: + if len(self._req_queue) < self.max_batch_size: + req_id = self._req_idx + self._req_idx += 1 + self._req_queue[req_id] = text + + if ( + self._current_batch_finished_event is None + or self._current_batch_full_event is None + ): + self._current_batch_finished_event = asyncio.Event() + self._current_batch_full_event = asyncio.Event() + self._current_batch_submitted.clear() + asyncio.ensure_future( + self._run_batch( + self._current_batch_full_event, + self._current_batch_finished_event, + ) + ) + + batch_finished_event = self._current_batch_finished_event + batch_full_event = self._current_batch_full_event + if batch_finished_event is None or batch_full_event is None: + raise RuntimeError( + "Batch events not initialized. This should not happen." 
+ ) + + # We check if we reached the max batch size + if len(self._req_queue) >= self.max_batch_size: + batch_full_event.set() + + break + + batch_submitted_event = self._current_batch_submitted + + await batch_submitted_event.wait() # Wait for the batch to finish - await self._current_batch_finished_event.wait() + await batch_finished_event.wait() # Remove the result and return it result = self._req_results[req_id] @@ -258,7 +280,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: return result - async def search(self, text: str, max_results: int = 20, threshold: Optional[float] = None) -> List[IndexItem]: + async def search( + self, text: str, max_results: int = 20, threshold: Optional[float] = None + ) -> List[IndexItem]: """Search the closest `max_results` items. Args: @@ -277,7 +301,9 @@ async def search(self, text: str, max_results: int = 20, threshold: Optional[flo _embedding = (await self._get_embeddings([text]))[0] if self._index is None: - raise ValueError("Index is not built yet. Ensure to call `build` before searching.") + raise ValueError( + "Index is not built yet. Ensure to call `build` before searching." 
+ ) results = self._index.get_nns_by_vector( _embedding, @@ -298,8 +324,14 @@ async def search(self, text: str, max_results: int = 20, threshold: Optional[flo return [self._items[i] for i in filtered_results] @staticmethod - def _filter_results(indices: List[int], distances: List[float], threshold: float) -> List[int]: + def _filter_results( + indices: List[int], distances: List[float], threshold: float + ) -> List[int]: if threshold == float("inf"): return indices else: - return [index for index, distance in zip(indices, distances) if (1 - distance / 2) >= threshold] + return [ + index + for index, distance in zip(indices, distances) + if (1 - distance / 2) >= threshold + ] diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index b20d63130d..c037f8a52f 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -22,10 +22,37 @@ from nemoguardrails.embeddings.index import IndexItem +class MockEmbeddingModel: + async def encode_async(self, texts): + await asyncio.sleep(0.01) + return [[float(text.split()[-1])] for text in texts] + + +@pytest.mark.asyncio +async def test_batch_get_embeddings_handles_concurrent_batches(): + embeddings_index = BasicEmbeddingsIndex( + use_batching=True, + max_batch_size=2, + max_batch_hold=0.01, + ) + embeddings_index._model = MockEmbeddingModel() + + results = await asyncio.wait_for( + asyncio.gather( + *(embeddings_index._batch_get_embeddings(f"text {i}") for i in range(5)) + ), + timeout=1, + ) + + assert sorted(result[0] for result in results) == [0, 1, 2, 3, 4] + + @pytest.mark.skip(reason="Run manually.") @pytest.mark.asyncio async def test_search_speed(): - embeddings_index = BasicEmbeddingsIndex(embedding_model="all-MiniLM-L6-v2", embedding_engine="SentenceTransformers") + embeddings_index = BasicEmbeddingsIndex( + embedding_model="all-MiniLM-L6-v2", embedding_engine="SentenceTransformers" + ) # We compute an initial embedding, to warm up the model. 
await embeddings_index._get_embeddings(["warm up"]) @@ -74,7 +101,9 @@ async def _search(text): t0 = time() semaphore = asyncio.Semaphore(concurrency) for i in range(requests): - task = asyncio.ensure_future(_search(f"This is a long sentence meant to mimic a user request {i}." * 5)) + task = asyncio.ensure_future( + _search(f"This is a long sentence meant to mimic a user request {i}." * 5) + ) tasks.append(task) await asyncio.gather(*tasks) @@ -83,5 +112,7 @@ async def _search(text): print(f"Processing {completed_requests} took {took:0.2f}.") print(f"Completed {completed_requests} requests in {total_time:.2f} seconds.") - print(f"Average latency: {total_time / completed_requests if completed_requests else 0:.2f} seconds.") + print( + f"Average latency: {total_time / completed_requests if completed_requests else 0:.2f} seconds." + ) print(f"Maximum concurrency: {concurrency}") From d1fbf6e39c8af5245e2376c31335dff6192bbff3 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Fri, 22 May 2026 23:20:50 +0530 Subject: [PATCH 02/10] fix: handle embedding errors in _run_batch and improve batching - Wrap _get_embeddings call in try/except/finally so batch_finished_event is always set even on exception; store per-req errors in _req_errors so callers observe the failure instead of hanging indefinitely - Replace asyncio.ensure_future with asyncio.get_event_loop().create_task (ensure_future is not itself deprecated, but since Python 3.10 it emits a DeprecationWarning when given a coroutine with no running event loop) - Add call_count to MockEmbeddingModel and assert at least 3 batches are dispatched for 5 requests with max_batch_size=2 - Apply ruff formatting across both changed files --- nemoguardrails/embeddings/basic.py | 59 +++++++++++------------------- tests/test_batch_embeddings.py | 25 ++++++------- 2 files changed, 34 insertions(+), 50 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 9babca7ca0..e769d2c74f 100644 --- a/nemoguardrails/embeddings/basic.py +++ 
b/nemoguardrails/embeddings/basic.py @@ -86,6 +86,7 @@ def __init__( # Data structures for batching embedding requests self._req_queue: Dict[int, str] = {} self._req_results: Dict[int, List[float]] = {} + self._req_errors: Dict[int, BaseException] = {} self._req_idx: int = 0 self._current_batch_finished_event: Optional[asyncio.Event] = None self._current_batch_full_event: Optional[asyncio.Event] = None @@ -176,9 +177,7 @@ async def add_items(self, items: List[IndexItem]): # If the index is already built, we skip this if self._index is None: - self._embeddings.extend( - await self._get_embeddings([item.text for item in items]) - ) + self._embeddings.extend(await self._get_embeddings([item.text for item in items])) # Update the embedding if it was not computed up to this point self._embedding_size = len(self._embeddings[0]) @@ -190,9 +189,7 @@ async def build(self): self._index.add_item(i, self._embeddings[i]) self._index.build(10) - async def _run_batch( - self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event - ): + async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. @@ -222,15 +219,15 @@ async def _run_batch( # We allow other batches to start. 
self._current_batch_submitted.set() - # print(f"Running batch of length {len(batch)}") - - # Compute the embeddings - embeddings = await self._get_embeddings(batch) - for i in range(len(embeddings)): - self._req_results[batch_ids[i]] = embeddings[i] - - # Signal that the batch has finished processing - batch_finished_event.set() + try: + embeddings = await self._get_embeddings(batch) + for i in range(len(embeddings)): + self._req_results[batch_ids[i]] = embeddings[i] + except BaseException as exc: + for req_id in batch_ids: + self._req_errors[req_id] = exc + finally: + batch_finished_event.set() async def _batch_get_embeddings(self, text: str) -> List[float]: while True: @@ -240,14 +237,11 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._req_idx += 1 self._req_queue[req_id] = text - if ( - self._current_batch_finished_event is None - or self._current_batch_full_event is None - ): + if self._current_batch_finished_event is None or self._current_batch_full_event is None: self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() - asyncio.ensure_future( + asyncio.get_event_loop().create_task( self._run_batch( self._current_batch_full_event, self._current_batch_finished_event, @@ -257,9 +251,7 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: batch_finished_event = self._current_batch_finished_event batch_full_event = self._current_batch_full_event if batch_finished_event is None or batch_full_event is None: - raise RuntimeError( - "Batch events not initialized. This should not happen." - ) + raise RuntimeError("Batch events not initialized. 
This should not happen.") # We check if we reached the max batch size if len(self._req_queue) >= self.max_batch_size: @@ -274,15 +266,16 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: # Wait for the batch to finish await batch_finished_event.wait() + if req_id in self._req_errors: + raise self._req_errors.pop(req_id) + # Remove the result and return it result = self._req_results[req_id] del self._req_results[req_id] return result - async def search( - self, text: str, max_results: int = 20, threshold: Optional[float] = None - ) -> List[IndexItem]: + async def search(self, text: str, max_results: int = 20, threshold: Optional[float] = None) -> List[IndexItem]: """Search the closest `max_results` items. Args: @@ -301,9 +294,7 @@ async def search( _embedding = (await self._get_embeddings([text]))[0] if self._index is None: - raise ValueError( - "Index is not built yet. Ensure to call `build` before searching." - ) + raise ValueError("Index is not built yet. Ensure to call `build` before searching.") results = self._index.get_nns_by_vector( _embedding, @@ -324,14 +315,8 @@ async def search( return [self._items[i] for i in filtered_results] @staticmethod - def _filter_results( - indices: List[int], distances: List[float], threshold: float - ) -> List[int]: + def _filter_results(indices: List[int], distances: List[float], threshold: float) -> List[int]: if threshold == float("inf"): return indices else: - return [ - index - for index, distance in zip(indices, distances) - if (1 - distance / 2) >= threshold - ] + return [index for index, distance in zip(indices, distances) if (1 - distance / 2) >= threshold] diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index c037f8a52f..fd7c96b380 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -23,36 +23,39 @@ class MockEmbeddingModel: + def __init__(self): + self.call_count = 0 + async def encode_async(self, texts): + self.call_count += 1 await 
asyncio.sleep(0.01) return [[float(text.split()[-1])] for text in texts] @pytest.mark.asyncio async def test_batch_get_embeddings_handles_concurrent_batches(): + mock_model = MockEmbeddingModel() embeddings_index = BasicEmbeddingsIndex( use_batching=True, max_batch_size=2, max_batch_hold=0.01, ) - embeddings_index._model = MockEmbeddingModel() + embeddings_index._model = mock_model results = await asyncio.wait_for( - asyncio.gather( - *(embeddings_index._batch_get_embeddings(f"text {i}") for i in range(5)) - ), + asyncio.gather(*(embeddings_index._batch_get_embeddings(f"text {i}") for i in range(5))), timeout=1, ) assert sorted(result[0] for result in results) == [0, 1, 2, 3, 4] + # 5 requests with max_batch_size=2 must produce at least 3 separate batches + assert mock_model.call_count >= 3 @pytest.mark.skip(reason="Run manually.") @pytest.mark.asyncio async def test_search_speed(): - embeddings_index = BasicEmbeddingsIndex( - embedding_model="all-MiniLM-L6-v2", embedding_engine="SentenceTransformers" - ) + embeddings_index = BasicEmbeddingsIndex(embedding_model="all-MiniLM-L6-v2", embedding_engine="SentenceTransformers") # We compute an initial embedding, to warm up the model. await embeddings_index._get_embeddings(["warm up"]) @@ -101,9 +104,7 @@ async def _search(text): t0 = time() semaphore = asyncio.Semaphore(concurrency) for i in range(requests): - task = asyncio.ensure_future( - _search(f"This is a long sentence meant to mimic a user request {i}." * 5) - ) + task = asyncio.ensure_future(_search(f"This is a long sentence meant to mimic a user request {i}." * 5)) tasks.append(task) await asyncio.gather(*tasks) @@ -112,7 +113,5 @@ async def _search(text): print(f"Processing {completed_requests} took {took:0.2f}.") print(f"Completed {completed_requests} requests in {total_time:.2f} seconds.") - print( - f"Average latency: {total_time / completed_requests if completed_requests else 0:.2f} seconds." 
- ) + print(f"Average latency: {total_time / completed_requests if completed_requests else 0:.2f} seconds.") print(f"Maximum concurrency: {concurrency}") From e358c9d42ea4efe8c0a3b4b90ab9f87d517d2d71 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Fri, 22 May 2026 23:22:39 +0530 Subject: [PATCH 03/10] test: cover error propagation path in _batch_get_embeddings Adds test_batch_get_embeddings_propagates_model_error to exercise the except/finally branches added to _run_batch and the re-raise in _batch_get_embeddings, bringing codecov patch coverage to 100%. --- tests/test_batch_embeddings.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index fd7c96b380..29299d218f 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -32,6 +32,27 @@ async def encode_async(self, texts): return [[float(text.split()[-1])] for text in texts] +class FailingEmbeddingModel: + async def encode_async(self, texts): + raise RuntimeError("embedding model failure") + + +@pytest.mark.asyncio +async def test_batch_get_embeddings_propagates_model_error(): + embeddings_index = BasicEmbeddingsIndex( + use_batching=True, + max_batch_size=2, + max_batch_hold=0.01, + ) + embeddings_index._model = FailingEmbeddingModel() + + with pytest.raises(RuntimeError, match="embedding model failure"): + await asyncio.wait_for( + embeddings_index._batch_get_embeddings("text 0"), + timeout=1, + ) + + @pytest.mark.asyncio async def test_batch_get_embeddings_handles_concurrent_batches(): mock_model = MockEmbeddingModel() From b13239f71d06016943d092afa8f0a936bda32533 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Fri, 22 May 2026 23:26:30 +0530 Subject: [PATCH 04/10] fix: use asyncio.create_task and retain task reference Replaces asyncio.get_event_loop().create_task() with asyncio.create_task() (get_event_loop() deprecated in Python 3.10+) and stores the 
returned Task in self._current_batch_task so it is not silently discarded. --- nemoguardrails/embeddings/basic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index e769d2c74f..bdac42d911 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -90,6 +90,7 @@ def __init__( self._req_idx: int = 0 self._current_batch_finished_event: Optional[asyncio.Event] = None self._current_batch_full_event: Optional[asyncio.Event] = None + self._current_batch_task: Optional[asyncio.Task] = None self._current_batch_submitted: asyncio.Event = asyncio.Event() self._batch_lock: asyncio.Lock = asyncio.Lock() @@ -241,7 +242,7 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() - asyncio.get_event_loop().create_task( + self._current_batch_task = asyncio.create_task( self._run_batch( self._current_batch_full_event, self._current_batch_finished_event, From 3792fb91bcdab7d3d3bde38b903df279dbd6472e Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Fri, 22 May 2026 23:48:34 +0530 Subject: [PATCH 05/10] fix: guard against short embedding results and drop unused done variable - Rename unused done from asyncio.wait to _ - After _get_embeddings returns, check that the number of embeddings matches batch_ids; populate _req_errors for any missing slots so callers always receive either a result or an exception --- nemoguardrails/embeddings/basic.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index bdac42d911..29ac119096 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -194,7 +194,7 @@ async def _run_batch(self, batch_full_event: asyncio.Event, 
batch_finished_event """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. - done, pending = await asyncio.wait( + _, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), asyncio.create_task(batch_full_event.wait()), @@ -222,11 +222,18 @@ async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event try: embeddings = await self._get_embeddings(batch) + if len(embeddings) < len(batch_ids): + shortage_exc = RuntimeError( + f"Embedding model returned {len(embeddings)} embeddings for {len(batch_ids)} inputs." + ) + for req_id in batch_ids[len(embeddings) :]: + self._req_errors[req_id] = shortage_exc for i in range(len(embeddings)): self._req_results[batch_ids[i]] = embeddings[i] except BaseException as exc: for req_id in batch_ids: - self._req_errors[req_id] = exc + if req_id not in self._req_results and req_id not in self._req_errors: + self._req_errors[req_id] = exc finally: batch_finished_event.set() From 35650d4b3b2e7825b45041bfd4a8ecda6fa33448 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Sat, 23 May 2026 00:11:05 +0530 Subject: [PATCH 06/10] test: cover short embedding result path for full codecov --- tests/test_batch_embeddings.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index 29299d218f..1d518e0f45 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -37,6 +37,32 @@ async def encode_async(self, texts): raise RuntimeError("embedding model failure") +class ShortEmbeddingModel: + """Returns one fewer embedding than requested.""" + + async def encode_async(self, texts): + return [[float(i)] for i in range(len(texts) - 1)] + + +@pytest.mark.asyncio +async def test_batch_get_embeddings_propagates_short_result(): + embeddings_index = BasicEmbeddingsIndex( + use_batching=True, + max_batch_size=4, + 
max_batch_hold=0.01, + ) + embeddings_index._model = ShortEmbeddingModel() + + with pytest.raises(RuntimeError, match="fewer embeddings"): + await asyncio.wait_for( + asyncio.gather( + embeddings_index._batch_get_embeddings("text 0"), + embeddings_index._batch_get_embeddings("text 1"), + ), + timeout=1, + ) + + @pytest.mark.asyncio async def test_batch_get_embeddings_propagates_model_error(): embeddings_index = BasicEmbeddingsIndex( From f90e588cc55d86aae15a4036861154cb2e5612fe Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Sat, 23 May 2026 00:11:43 +0530 Subject: [PATCH 07/10] fix: propagate CancelledError and prevent req slot leaks - Split BaseException handler into CancelledError (stores error, re-raises so task cancellation propagates correctly) and Exception (stores error without re-raising); batch_finished_event is still set via finally in both cases - Wrap the wait-and-retrieve block in _batch_get_embeddings with try/finally so _req_results and _req_errors entries are always cleaned up even when the caller is cancelled mid-wait --- nemoguardrails/embeddings/basic.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 29ac119096..5f3cdb0b7f 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -230,7 +230,12 @@ async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event self._req_errors[req_id] = shortage_exc for i in range(len(embeddings)): self._req_results[batch_ids[i]] = embeddings[i] - except BaseException as exc: + except asyncio.CancelledError as exc: + for req_id in batch_ids: + if req_id not in self._req_results and req_id not in self._req_errors: + self._req_errors[req_id] = exc + raise + except Exception as exc: for req_id in batch_ids: if req_id not in self._req_results and req_id not in self._req_errors: self._req_errors[req_id] = exc @@ -271,15 
+276,17 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: await batch_submitted_event.wait() - # Wait for the batch to finish - await batch_finished_event.wait() + # Wait for the batch to finish; clean up our slot regardless of how we exit. + try: + await batch_finished_event.wait() - if req_id in self._req_errors: - raise self._req_errors.pop(req_id) + if req_id in self._req_errors: + raise self._req_errors.pop(req_id) - # Remove the result and return it - result = self._req_results[req_id] - del self._req_results[req_id] + result = self._req_results.pop(req_id) + finally: + self._req_results.pop(req_id, None) + self._req_errors.pop(req_id, None) return result From 8b59599701965b8ed8597c9c5485357e967d94d0 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Sat, 23 May 2026 11:34:15 +0530 Subject: [PATCH 08/10] fix: correct test match pattern, harden _run_batch, cover CancelledError - Fix test_batch_get_embeddings_propagates_short_result: match pattern 'fewer embeddings' did not appear in the actual error message; corrected to 'Embedding model returned' (CI was failing on all Python versions) - Move try/finally to wrap the full _run_batch body (including the initial asyncio.wait and lock section) so batch_finished_event is guaranteed to be set regardless of where a cancellation or error fires; batch_ids initialised to [] before the try so except handlers never hit NameError - Add comment to _current_batch_task explaining it is stored for lifecycle control (e.g. 
cancellation, shutdown) - Add test_batch_get_embeddings_propagates_cancelled_batch_task: signals when encode_async starts, cancels the batch task at that point, and asserts CancelledError propagates to the caller; covers the previously uncovered except asyncio.CancelledError block --- nemoguardrails/embeddings/basic.py | 51 ++++++++++++++++-------------- tests/test_batch_embeddings.py | 28 +++++++++++++++- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 5f3cdb0b7f..611f532045 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -90,6 +90,7 @@ def __init__( self._req_idx: int = 0 self._current_batch_finished_event: Optional[asyncio.Event] = None self._current_batch_full_event: Optional[asyncio.Event] = None + # Stored so callers can cancel or inspect the active batch task (e.g. shutdown). self._current_batch_task: Optional[asyncio.Task] = None self._current_batch_submitted: asyncio.Event = asyncio.Event() self._batch_lock: asyncio.Lock = asyncio.Lock() @@ -192,35 +193,37 @@ async def build(self): async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event: asyncio.Event): """Runs the current batch of embeddings.""" + # Initialised here so the except handlers can safely iterate even if + # cancellation fires before the lock section populates batch_ids. + batch_ids: List[int] = [] + try: + # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. + _, pending = await asyncio.wait( + [ + asyncio.create_task(asyncio.sleep(self.max_batch_hold)), + asyncio.create_task(batch_full_event.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() - # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. 
- _, pending = await asyncio.wait( - [ - asyncio.create_task(asyncio.sleep(self.max_batch_hold)), - asyncio.create_task(batch_full_event.wait()), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - - async with self._batch_lock: - # Reset the active batch only if it has not already rolled over. - if self._current_batch_finished_event is batch_finished_event: - self._current_batch_finished_event = None - self._current_batch_full_event = None + async with self._batch_lock: + # Reset the active batch only if it has not already rolled over. + if self._current_batch_finished_event is batch_finished_event: + self._current_batch_finished_event = None + self._current_batch_full_event = None - # Create the actual batch to be sent for computing. - batch_ids = list(self._req_queue.keys()) - batch = [self._req_queue[req_id] for req_id in batch_ids] + # Create the actual batch to be sent for computing. + batch_ids = list(self._req_queue.keys()) + batch = [self._req_queue[req_id] for req_id in batch_ids] - # Empty the queue up to this point. - self._req_queue = {} + # Empty the queue up to this point. + self._req_queue = {} - # We allow other batches to start. - self._current_batch_submitted.set() + # We allow other batches to start. 
+ self._current_batch_submitted.set() - try: embeddings = await self._get_embeddings(batch) if len(embeddings) < len(batch_ids): shortage_exc = RuntimeError( diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index 1d518e0f45..90979e49ef 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -53,7 +53,7 @@ async def test_batch_get_embeddings_propagates_short_result(): ) embeddings_index._model = ShortEmbeddingModel() - with pytest.raises(RuntimeError, match="fewer embeddings"): + with pytest.raises(RuntimeError, match="Embedding model returned"): await asyncio.wait_for( asyncio.gather( embeddings_index._batch_get_embeddings("text 0"), @@ -63,6 +63,32 @@ async def test_batch_get_embeddings_propagates_short_result(): ) +@pytest.mark.asyncio +async def test_batch_get_embeddings_propagates_cancelled_batch_task(): + """Cancelling the active _run_batch task must wake callers with CancelledError.""" + encoding_started = asyncio.Event() + + class HangingModel: + async def encode_async(self, texts): + encoding_started.set() + await asyncio.sleep(10) + + embeddings_index = BasicEmbeddingsIndex( + use_batching=True, + max_batch_size=10, + max_batch_hold=0.001, + ) + embeddings_index._model = HangingModel() + + caller = asyncio.create_task(embeddings_index._batch_get_embeddings("text 0")) + # Wait until _get_embeddings has been entered so cancellation hits that await. 
+ await asyncio.wait_for(encoding_started.wait(), timeout=1) + embeddings_index._current_batch_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await caller + + @pytest.mark.asyncio async def test_batch_get_embeddings_propagates_model_error(): embeddings_index = BasicEmbeddingsIndex( From edb6464e9510aedf35013c28b94d08f66fd5903f Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Sat, 23 May 2026 11:48:56 +0530 Subject: [PATCH 09/10] fix: handle early cancellation of _run_batch before batch_ids is populated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If _run_batch is cancelled during the asyncio.wait (before acquiring the lock), batch_ids remains empty. Drain _req_queue without the lock — safe because asyncio is single-threaded and no await occurs before `raise`. Also unconditionally call _current_batch_submitted.set() in finally so queue-full waiters are never left blocked. Add test_batch_get_embeddings_propagates_early_cancellation to cover this path explicitly. --- nemoguardrails/embeddings/basic.py | 12 ++++++++++++ tests/test_batch_embeddings.py | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 611f532045..4f948f0eee 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -234,6 +234,15 @@ async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event for i in range(len(embeddings)): self._req_results[batch_ids[i]] = embeddings[i] except asyncio.CancelledError as exc: + # If cancelled before the lock section, batch_ids is still []. + # Snapshot and drain the queue without the lock — safe because asyncio + # is single-threaded and no await occurs between here and `raise`. 
+ if not batch_ids: + batch_ids = list(self._req_queue.keys()) + self._req_queue = {} + if self._current_batch_finished_event is batch_finished_event: + self._current_batch_finished_event = None + self._current_batch_full_event = None for req_id in batch_ids: if req_id not in self._req_results and req_id not in self._req_errors: self._req_errors[req_id] = exc @@ -243,6 +252,9 @@ async def _run_batch(self, batch_full_event: asyncio.Event, batch_finished_event if req_id not in self._req_results and req_id not in self._req_errors: self._req_errors[req_id] = exc finally: + # Unconditionally unblock full-queue waiters in case the lock section + # was never reached (early cancellation). + self._current_batch_submitted.set() batch_finished_event.set() async def _batch_get_embeddings(self, text: str) -> List[float]: diff --git a/tests/test_batch_embeddings.py b/tests/test_batch_embeddings.py index 90979e49ef..11fce8c925 100644 --- a/tests/test_batch_embeddings.py +++ b/tests/test_batch_embeddings.py @@ -89,6 +89,28 @@ async def encode_async(self, texts): await caller +@pytest.mark.asyncio +async def test_batch_get_embeddings_propagates_early_cancellation(): + """Cancelling _run_batch before batch_ids is populated must still wake callers.""" + embeddings_index = BasicEmbeddingsIndex( + use_batching=True, + max_batch_size=10, + max_batch_hold=10, # Long hold so cancellation fires during asyncio.wait + ) + embeddings_index._model = MockEmbeddingModel() + + caller = asyncio.create_task(embeddings_index._batch_get_embeddings("text 0")) + # Yield to let _batch_get_embeddings register the request and start _run_batch, + # then cancel before max_batch_hold elapses (i.e. before the lock section runs). 
+ await asyncio.sleep(0) + await asyncio.sleep(0) + assert embeddings_index._current_batch_task is not None + embeddings_index._current_batch_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(caller, timeout=1) + + @pytest.mark.asyncio async def test_batch_get_embeddings_propagates_model_error(): embeddings_index = BasicEmbeddingsIndex( From 56e002fd14db7b35ce3d97ef82caa48a5e3cb6a9 Mon Sep 17 00:00:00 2001 From: ppradyoth <4ni19is062_a@nie.ac.in> Date: Sat, 23 May 2026 11:54:28 +0530 Subject: [PATCH 10/10] fix: raise RuntimeError instead of KeyError when result slot is missing Guard `_req_results.pop(req_id)` in `_batch_get_embeddings` so that if a req_id is absent from both `_req_results` and `_req_errors` when the batch event fires, callers get a clear RuntimeError rather than a confusing KeyError. Paired with the early-cancellation fix that ensures all in-flight request IDs are written to `_req_errors` before the event is set. --- nemoguardrails/embeddings/basic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 4f948f0eee..68b9a5ad19 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -298,6 +298,8 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: if req_id in self._req_errors: raise self._req_errors.pop(req_id) + if req_id not in self._req_results: + raise RuntimeError(f"Batch completed without a result for request {req_id}.") result = self._req_results.pop(req_id) finally: self._req_results.pop(req_id, None)