From 4fabed63c1f6383d3ca82c72b965777974268fce Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Wed, 13 May 2026 16:13:14 +0100 Subject: [PATCH] Overhaul self-check flows: unify action logic, provide per-task configuration, allow for multiple self check tasks --- .../guardrail-catalog/self-check.md | 93 +++ .../library/self_check/input_check/actions.py | 97 --- .../library/self_check/input_check/flows.co | 10 - .../self_check/input_check/flows.v1.co | 12 - .../__init__.py | 0 .../self_check/message_check/actions.py | 219 +++++++ .../library/self_check/message_check/flows.co | 21 + .../self_check/message_check/flows.v1.co | 28 + .../self_check/output_check/__init__.py | 14 - .../self_check/output_check/actions.py | 97 --- .../library/self_check/output_check/flows.co | 10 - .../self_check/output_check/flows.v1.co | 12 - tests/test_multiple_self_check_rails.py | 588 ++++++++++++++++++ 13 files changed, 949 insertions(+), 252 deletions(-) delete mode 100644 nemoguardrails/library/self_check/input_check/actions.py delete mode 100644 nemoguardrails/library/self_check/input_check/flows.co delete mode 100644 nemoguardrails/library/self_check/input_check/flows.v1.co rename nemoguardrails/library/self_check/{input_check => message_check}/__init__.py (100%) create mode 100644 nemoguardrails/library/self_check/message_check/actions.py create mode 100644 nemoguardrails/library/self_check/message_check/flows.co create mode 100644 nemoguardrails/library/self_check/message_check/flows.v1.co delete mode 100644 nemoguardrails/library/self_check/output_check/__init__.py delete mode 100644 nemoguardrails/library/self_check/output_check/actions.py delete mode 100644 nemoguardrails/library/self_check/output_check/flows.co delete mode 100644 nemoguardrails/library/self_check/output_check/flows.v1.co create mode 100644 tests/test_multiple_self_check_rails.py diff --git a/docs/configure-rails/guardrail-catalog/self-check.md b/docs/configure-rails/guardrail-catalog/self-check.md index 
d7fbe113c3..6508c434a0 100644 --- a/docs/configure-rails/guardrail-catalog/self-check.md +++ b/docs/configure-rails/guardrail-catalog/self-check.md @@ -77,6 +77,52 @@ define bot refuse to respond "I'm sorry, I can't respond to that." ``` +### Running Multiple Self-Check Input Rails + +The `self check input` flow accepts an `$input_task` parameter (defaulting to `self_check_input`) that controls which prompt is used for checking. This lets you run multiple input checks with different criteria — for example, checking for both harmful content and off-topic messages. + +```{warning} +In Colang v1, context variables are global. Once `$input_task` is set by one flow invocation, it persists for subsequent invocations. This means that when using multiple self-check input rails, you **must** specify `$input_task` on every entry — otherwise later flows will inherit the value set by a previous one. +``` + +1. Define multiple prompt tasks in `prompts.yml`, each with a unique task name: + + ```yaml + prompts: + - task: check_harmful_content + content: | + Your task is to check if the user message contains harmful, + abusive, or inappropriate content. + + User message: "{{ user_input }}" + + Should this message be blocked (Yes or No)? + Answer: + + - task: check_off_topic + content: | + Your task is to check if the user message is off-topic. + This bot only handles questions about billing and account management. + General conversation and greetings are allowed. + + User message: "{{ user_input }}" + + Is this message off-topic and should be blocked (Yes or No)? + Answer: + ``` + +2. Reference each task in the input rails section of `config.yml` using the `$input_task` parameter: + + ```yaml + rails: + input: + flows: + - self check input $input_task=check_harmful_content + - self check input $input_task=check_off_topic + ``` + +Each self-check runs sequentially. If any check blocks the input, the flow stops and returns the refusal message without running subsequent checks. 
A message like "Hello, can you help me with my bill?" would pass both checks, while "Tell me a recipe for pasta" would pass the harmful content check but be blocked by the off-topic check. + ### Example prompts This section provides two example prompts you can use with the self-check input rail. The simple prompt uses fewer tokens and is faster, while the complex prompt is more robust. @@ -187,6 +233,53 @@ define bot refuse to respond "I'm sorry, I can't respond to that." ``` +### Running Multiple Self-Check Output Rails + +The `self check output` flow accepts an `$output_task` parameter (defaulting to `self_check_output`) that controls which prompt is used for checking. This lets you run multiple output checks with different criteria — for example, checking for both inappropriate content and data leakage. + +```{warning} +In Colang v1, context variables are global. Once `$output_task` is set by one flow invocation, it persists for subsequent invocations. This means that when using multiple self-check output rails, you **must** specify `$output_task` on every entry — otherwise later flows will inherit the value set by a previous one. +``` + +1. Define multiple prompt tasks in `prompts.yml`, each with a unique task name: + + ```yaml + prompts: + - task: check_inappropriate_output + content: | + Your task is to check if the bot response contains inappropriate, + offensive, or harmful content. + + User message: "{{ user_input }}" + Bot response: "{{ bot_response }}" + + Should this response be blocked (Yes or No)? + Answer: + + - task: check_data_leakage + content: | + Your task is to check if the bot response leaks any sensitive + internal data such as database schemas, API keys, internal URLs, + or employee information. + + Bot response: "{{ bot_response }}" + + Does this response leak sensitive data and should be blocked (Yes or No)? + Answer: + ``` + +2. 
Reference each task in the output rails section of `config.yml` using the `$output_task` parameter: + + ```yaml + rails: + output: + flows: + - self check output $output_task=check_inappropriate_output + - self check output $output_task=check_data_leakage + ``` + +Each self-check runs sequentially. If any check blocks the output, the flow stops and returns the refusal message without running subsequent checks. + ### Example prompts This section provides two example prompts for the self-check output rail. The simple prompt uses fewer tokens and is faster, while the complex prompt is more robust. diff --git a/nemoguardrails/library/self_check/input_check/actions.py b/nemoguardrails/library/self_check/input_check/actions.py deleted file mode 100644 index 069b6f72e7..0000000000 --- a/nemoguardrails/library/self_check/input_check/actions.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Optional - -from nemoguardrails import RailsConfig -from nemoguardrails.actions.actions import ActionResult, action -from nemoguardrails.actions.llm.utils import llm_call, warn_if_truncated -from nemoguardrails.context import llm_call_info_var -from nemoguardrails.llm.taskmanager import LLMTaskManager -from nemoguardrails.llm.types import Task -from nemoguardrails.logging.explain import LLMCallInfo -from nemoguardrails.types import LLMModel -from nemoguardrails.utils import new_event_dict - -log = logging.getLogger(__name__) - - -@action(is_system_action=True) -async def self_check_input( - llm_task_manager: LLMTaskManager, - context: Optional[dict] = None, - llm: Optional[LLMModel] = None, - config: Optional[RailsConfig] = None, - **kwargs, -): - """Checks the input from the user. - - Prompt the LLM, using the `check_input` task prompt, to determine if the input - from the user should be allowed or not. - - Returns: - True if the input should be allowed, False otherwise. 
- """ - - _MAX_TOKENS = 1024 - user_input = context.get("user_message") - task = Task.SELF_CHECK_INPUT - - if user_input: - prompt = llm_task_manager.render_task_prompt( - task=task, - context={ - "user_input": user_input, - }, - ) - stop = llm_task_manager.get_stop_tokens(task=task) - max_tokens = llm_task_manager.get_max_tokens(task=task) - max_tokens = max_tokens or _MAX_TOKENS - - # Initialize the LLMCallInfo object - llm_call_info_var.set(LLMCallInfo(task=task.value)) - - llm_response = await llm_call( - llm, - prompt, - stop=stop, - llm_params={ - "temperature": config.lowest_temperature, - "max_tokens": max_tokens, - }, - ) - warn_if_truncated(llm_response, task.value) - response = llm_response.content - - log.info(f"Input self-checking result is: `{response}`.") - - # for sake of backward compatibility - # if the output_parser is not registered we will use the default one - if llm_task_manager.has_output_parser(task): - result = llm_task_manager.parse_task_output(task, output=response) - - else: - result = llm_task_manager.parse_task_output(task, output=response, forced_output_parser="is_content_safe") - - is_safe = result[0] - - if not is_safe: - return ActionResult( - return_value=False, - events=[new_event_dict("mask_prev_user_message", intent="unanswerable message")], - ) - - return is_safe diff --git a/nemoguardrails/library/self_check/input_check/flows.co b/nemoguardrails/library/self_check/input_check/flows.co deleted file mode 100644 index 2313d8866b..0000000000 --- a/nemoguardrails/library/self_check/input_check/flows.co +++ /dev/null @@ -1,10 +0,0 @@ - -flow self check input - $allowed = await SelfCheckInputAction - - if not $allowed - if $system.config.enable_rails_exceptions - send InputRailException(message="Input not allowed. 
The input was blocked by the 'self check input' flow.") - else - bot refuse to respond - abort diff --git a/nemoguardrails/library/self_check/input_check/flows.v1.co b/nemoguardrails/library/self_check/input_check/flows.v1.co deleted file mode 100644 index dbd04ea1e7..0000000000 --- a/nemoguardrails/library/self_check/input_check/flows.v1.co +++ /dev/null @@ -1,12 +0,0 @@ -define bot refuse to respond - "I'm sorry, I can't respond to that." - -define flow self check input - $allowed = execute self_check_input - - if not $allowed - if $config.enable_rails_exceptions - create event InputRailException(message="Input not allowed. The input was blocked by the 'self check input' flow.") - else - bot refuse to respond - stop diff --git a/nemoguardrails/library/self_check/input_check/__init__.py b/nemoguardrails/library/self_check/message_check/__init__.py similarity index 100% rename from nemoguardrails/library/self_check/input_check/__init__.py rename to nemoguardrails/library/self_check/message_check/__init__.py diff --git a/nemoguardrails/library/self_check/message_check/actions.py b/nemoguardrails/library/self_check/message_check/actions.py new file mode 100644 index 0000000000..86d4163585 --- /dev/null +++ b/nemoguardrails/library/self_check/message_check/actions.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Optional + +from nemoguardrails import RailsConfig +from nemoguardrails.actions.actions import ActionResult, action +from nemoguardrails.actions.llm.utils import llm_call, warn_if_truncated +from nemoguardrails.context import llm_call_info_var +from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.types import LLMModel +from nemoguardrails.utils import new_event_dict + +log = logging.getLogger(__name__) + +DEFAULT_INPUT_TASK = "self_check_input" +DEFAULT_OUTPUT_TASK = "self_check_output" + +USER_MESSAGE = "user_message" +BOT_MESSAGE = "bot_message" + + +def _get_llm( + llms: Dict[str, LLMModel], task: str, default_task: str, default_llm: Optional[LLMModel] +) -> Optional[LLMModel]:  # lookup order: task model, default_task model, then default_llm + if task in llms: + return llms[task] + elif default_task in llms: + log.warning(f"No model found with type={task}, falling back to default {default_task} model") + return llms[default_task] + elif default_llm is not None: + log.warning(f"No model found with type={task} or type={default_task}, falling back to main model") + return default_llm + else: + error_msg = ( + f"No matching model for task={task} found. " 
+ f"Please configure a model with type={task}, type={default_task}, or type=main" + ) + raise ValueError(error_msg) + + +async def self_check( + llms: Dict[str, LLMModel], + message_type: str, + llm_task_manager: LLMTaskManager, + default_llm: Optional[LLMModel] = None, + context: Optional[dict] = None, + config: Optional[RailsConfig] = None, + task: str = DEFAULT_OUTPUT_TASK, + **kwargs, +):  # shared implementation behind self_check_input / self_check_output + _MAX_TOKENS = 1024 + bot_response = (context or {}).get(BOT_MESSAGE)  # context defaults to None; treat it as empty + user_input = (context or {}).get(USER_MESSAGE) + bot_thinking = (context or {}).get("bot_thinking") + + if message_type == USER_MESSAGE: + default_task = DEFAULT_INPUT_TASK + elif message_type == BOT_MESSAGE: + default_task = DEFAULT_OUTPUT_TASK + else: + raise ValueError(f"Message type {message_type} not yet supported") + + # guard against an unset $task variable + if task.startswith("$") and task.endswith("_task"): + task = default_task + + # load model for this task, falling back to the default_task model or the injected llm if needed + llm = _get_llm(llms, task, default_task=default_task, default_llm=default_llm) + + # build prompts + if message_type == USER_MESSAGE: + if user_input: + prompt = llm_task_manager.render_task_prompt( + task=task, + context={ + "user_input": user_input, + }, + ) + + else: + raise ValueError("Self-check called on a user_message but no $user_message found in context") + elif message_type == BOT_MESSAGE: + if bot_response: + prompt = llm_task_manager.render_task_prompt( + task=task, + context={ + "user_input": user_input, + "bot_response": bot_response, + "bot_thinking": bot_thinking, + }, + ) + else: + raise ValueError("Self-check called on a bot_message but no $bot_message found in context") + + stop = llm_task_manager.get_stop_tokens(task=task) + max_tokens = llm_task_manager.get_max_tokens(task=task) + max_tokens = max_tokens or _MAX_TOKENS + + # Initialize the LLMCallInfo object + llm_call_info_var.set(LLMCallInfo(task=task)) + + llm_response = await llm_call( + llm, + prompt, + stop=stop, + 
llm_params={ + "temperature": config.lowest_temperature, + "max_tokens": max_tokens, + }, + ) + warn_if_truncated(llm_response, task) + response = llm_response.content + + # preserve log messages for backwards compatibility + if message_type == USER_MESSAGE: + if task == "self_check_input": + log.info(f"Input self-checking result is: `{response}`.") + else: + log.info(f"Input self-checking result for task={task} is: `{response}`.") + elif message_type == BOT_MESSAGE: + if task == "self_check_output": + log.info(f"Output self-checking result is: `{response}`.") + else: + log.info(f"Output self-checking result for task={task} is: `{response}`.") + + # for sake of backward compatibility + # if the output_parser is not registered we will use the default one + if llm_task_manager.has_output_parser(task): + result = llm_task_manager.parse_task_output(task, output=response) + else: + result = llm_task_manager.parse_task_output(task, output=response, forced_output_parser="is_content_safe") + + is_safe = result[0] + + if message_type == USER_MESSAGE and not is_safe: + return ActionResult( + return_value=False, + events=[new_event_dict("mask_prev_user_message", intent="unanswerable message")], + ) + return is_safe + + +@action(is_system_action=True) +async def self_check_input( + llms: Dict[str, LLMModel], + llm_task_manager: LLMTaskManager, + llm: Optional[LLMModel] = None, + context: Optional[dict] = None, + config: Optional[RailsConfig] = None, + task: str = DEFAULT_INPUT_TASK, + **kwargs, +): + """Checks the input from the user. + + Prompt the LLM, using the prompt of the given `task` (default `self_check_input`), to determine if the input + from the user should be allowed or not. + + Returns: + True if the input should be allowed, False otherwise. 
+ """ + return await self_check( + llms=llms, + message_type=USER_MESSAGE, + llm_task_manager=llm_task_manager, + default_llm=llm, + context=context, + config=config, + task=task, + **kwargs, + ) + + +@action(is_system_action=True, output_mapping=lambda value: not value) +async def self_check_output( + llms: Dict[str, LLMModel], + llm_task_manager: LLMTaskManager, + llm: Optional[LLMModel] = None, + context: Optional[dict] = None, + config: Optional[RailsConfig] = None, + task: str = DEFAULT_OUTPUT_TASK, + **kwargs, +): + """Checks the output from the bot. + + Prompt the LLM, using the prompt of the given `task` (default `self_check_output`), to determine if the output + from the bot should be allowed or not. + + The LLM call should return "yes" if the output is bad and should be blocked + (this is consistent with self_check_input_prompt). + + Returns: + True if the output should be allowed, False otherwise. + """ + + return await self_check( + llms=llms, + message_type=BOT_MESSAGE, + llm_task_manager=llm_task_manager, + default_llm=llm, + context=context, + config=config, + task=task, + **kwargs, + ) diff --git a/nemoguardrails/library/self_check/message_check/flows.co b/nemoguardrails/library/self_check/message_check/flows.co new file mode 100644 index 0000000000..16ea1cfcf7 --- /dev/null +++ b/nemoguardrails/library/self_check/message_check/flows.co @@ -0,0 +1,21 @@ + +flow self check input $input_task="self_check_input" + $allowed = await SelfCheckInputAction(task=$input_task) + + if not $allowed + if $system.config.enable_rails_exceptions + send InputRailException(message="Input not allowed. The input was blocked by the 'self check input' flow.") + else + bot refuse to respond + abort + + +flow self check output $output_task="self_check_output" + $allowed = await SelfCheckOutputAction(task=$output_task) + + if not $allowed + if $system.config.enable_rails_exceptions + send OutputRailException(message="Output not allowed. 
The output was blocked by the 'self check output' flow.") + else + bot refuse to respond + abort diff --git a/nemoguardrails/library/self_check/message_check/flows.v1.co b/nemoguardrails/library/self_check/message_check/flows.v1.co new file mode 100644 index 0000000000..53ec280cb3 --- /dev/null +++ b/nemoguardrails/library/self_check/message_check/flows.v1.co @@ -0,0 +1,28 @@ +define bot refuse to respond + "I'm sorry, I can't respond to that." + +define flow self check input + if not $input_task + $input_task = "self_check_input" + + $allowed = execute self_check_input(task=$input_task) + + if not $allowed + if $config.enable_rails_exceptions + create event InputRailException(message="Input not allowed. The input was blocked by the 'self check input' flow.") + else + bot refuse to respond + stop + +define flow self check output + if not $output_task + $output_task = "self_check_output" + + $allowed = execute self_check_output(task=$output_task) + + if not $allowed + if $config.enable_rails_exceptions + create event OutputRailException(message="Output not allowed. The output was blocked by the 'self check output' flow.") + else + bot refuse to respond + stop diff --git a/nemoguardrails/library/self_check/output_check/__init__.py b/nemoguardrails/library/self_check/output_check/__init__.py deleted file mode 100644 index 6c7f64065d..0000000000 --- a/nemoguardrails/library/self_check/output_check/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py deleted file mode 100644 index 47bbd161b9..0000000000 --- a/nemoguardrails/library/self_check/output_check/actions.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Optional - -from nemoguardrails import RailsConfig -from nemoguardrails.actions import action -from nemoguardrails.actions.llm.utils import llm_call, warn_if_truncated -from nemoguardrails.context import llm_call_info_var -from nemoguardrails.llm.taskmanager import LLMTaskManager -from nemoguardrails.llm.types import Task -from nemoguardrails.logging.explain import LLMCallInfo -from nemoguardrails.types import LLMModel - -log = logging.getLogger(__name__) - - -@action(is_system_action=True, output_mapping=lambda value: not value) -async def self_check_output( - llm_task_manager: LLMTaskManager, - context: Optional[dict] = None, - llm: Optional[LLMModel] = None, - config: Optional[RailsConfig] = None, - **kwargs, -): - """Checks if the output from the bot. - - Prompt the LLM, using the `self_check_output` task prompt, to determine if the output - from the bot should be allowed or not. - - The LLM call should return "yes" if the output is bad and should be blocked - (this is consistent with self_check_input_prompt). - - Returns: - True if the output should be allowed, False otherwise. 
- """ - - _MAX_TOKENS = 1024 - bot_response = context.get("bot_message") - user_input = context.get("user_message") - bot_thinking = context.get("bot_thinking") - - task = Task.SELF_CHECK_OUTPUT - - if bot_response: - prompt = llm_task_manager.render_task_prompt( - task=task, - context={ - "user_input": user_input, - "bot_response": bot_response, - "bot_thinking": bot_thinking, - }, - ) - stop = llm_task_manager.get_stop_tokens(task=task) - max_tokens = llm_task_manager.get_max_tokens(task=task) - max_tokens = max_tokens or _MAX_TOKENS - - # Initialize the LLMCallInfo object - llm_call_info_var.set(LLMCallInfo(task=task.value)) - - llm_response = await llm_call( - llm, - prompt, - stop=stop, - llm_params={ - "temperature": config.lowest_temperature, - "max_tokens": max_tokens, - }, - ) - warn_if_truncated(llm_response, task.value) - response = llm_response.content - - log.info(f"Output self-checking result is: `{response}`.") - - # for sake of backward compatibility - # if the output_parser is not registered we will use the default one - if llm_task_manager.has_output_parser(task): - result = llm_task_manager.parse_task_output(task, output=response) - else: - result = llm_task_manager.parse_task_output(task, output=response, forced_output_parser="is_content_safe") - - is_safe = result[0] - - return is_safe diff --git a/nemoguardrails/library/self_check/output_check/flows.co b/nemoguardrails/library/self_check/output_check/flows.co deleted file mode 100644 index b50684ada8..0000000000 --- a/nemoguardrails/library/self_check/output_check/flows.co +++ /dev/null @@ -1,10 +0,0 @@ - -flow self check output - $allowed = await SelfCheckOutputAction - - if not $allowed - if $system.config.enable_rails_exceptions - send OutputRailException(message="Output not allowed. 
The output was blocked by the 'self check output' flow.") - else - bot refuse to respond - abort diff --git a/nemoguardrails/library/self_check/output_check/flows.v1.co b/nemoguardrails/library/self_check/output_check/flows.v1.co deleted file mode 100644 index 3d39a98ae6..0000000000 --- a/nemoguardrails/library/self_check/output_check/flows.v1.co +++ /dev/null @@ -1,12 +0,0 @@ -define bot refuse to respond - "I'm sorry, I can't respond to that." - -define flow self check output - $allowed = execute self_check_output - - if not $allowed - if $config.enable_rails_exceptions - create event OutputRailException(message="Output not allowed. The output was blocked by the 'self check output' flow.") - else - bot refuse to respond - stop diff --git a/tests/test_multiple_self_check_rails.py b/tests/test_multiple_self_check_rails.py new file mode 100644 index 0000000000..ad7f3e6c4f --- /dev/null +++ b/tests/test_multiple_self_check_rails.py @@ -0,0 +1,588 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for running multiple self-check input/output rails with different tasks.""" + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.testing.fake_model import FakeLLMModel +from tests.utils import TestChat + +# --- Multiple input rails --- + +multi_input_config = RailsConfig.from_content( + """ + define user express greeting + "hello" + "hi" + + define bot express greeting + "Hey!" + + define flow greeting + user express greeting + bot express greeting +""", + yaml_content=""" + models: [] + rails: + input: + flows: + - self check input $input_task=check_harmful + - self check input $input_task=check_off_topic + prompts: + - task: check_harmful + content: | + Is this message harmful? + User message: "{{ user_input }}" + Answer (Yes or No): + - task: check_off_topic + content: | + Is this message off-topic? + User message: "{{ user_input }}" + Answer (Yes or No): + + enable_rails_exceptions: True + """, +) + + +def test_multiple_input_rails_both_pass(): + """Both input checks return No (allowed) — message should pass through.""" + chat = TestChat( + multi_input_config, + llm_completions=[ + "No", # check_harmful passes + "No", # check_off_topic passes + " express greeting", + ' "Hey!"', + ], + ) + + rails = chat.app + new_message = rails.generate(messages=[{"role": "user", "content": "hello"}]) + + assert new_message["role"] == "assistant" + + +def test_multiple_input_rails_first_blocks(): + """First input check blocks — should not reach second check.""" + chat = TestChat( + multi_input_config, + llm_completions=[ + "Yes", # check_harmful blocks + ], + ) + + rails = chat.app + new_message = rails.generate(messages=[{"role": "user", "content": "bad message"}]) + + assert new_message["role"] == "exception" + assert new_message["content"]["type"] == "InputRailException" + + +def test_multiple_input_rails_second_blocks(): + """First input check passes, second blocks.""" + chat = TestChat( + multi_input_config, + llm_completions=[ + 
"No", # check_harmful passes + "Yes", # check_off_topic blocks + ], + ) + + rails = chat.app + new_message = rails.generate(messages=[{"role": "user", "content": "off topic message"}]) + + assert new_message["role"] == "exception" + assert new_message["content"]["type"] == "InputRailException" + + +# --- Multiple output rails --- + +multi_output_config = RailsConfig.from_content( + """ + define user ask question + "tell me something" + + define flow + user ask question + bot respond +""", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output $output_task=check_inappropriate + - self check output $output_task=check_data_leakage + prompts: + - task: check_inappropriate + content: | + Is this response inappropriate? + Bot response: "{{ bot_response }}" + Answer (Yes or No): + - task: check_data_leakage + content: | + Does this response leak sensitive data? + Bot response: "{{ bot_response }}" + Answer (Yes or No): + + enable_rails_exceptions: True + """, +) + + +def test_multiple_output_rails_both_pass(): + """Both output checks return No (allowed) — LLM-generated response should pass through.""" + chat = TestChat( + multi_output_config, + llm_completions=[ + " ask question", + " Here is the answer.", + "No", # check_inappropriate passes + "No", # check_data_leakage passes + ], + ) + + rails = chat.app + new_message = rails.generate(messages=[{"role": "user", "content": "tell me something"}]) + + assert new_message["role"] == "assistant" + assert new_message["content"] == "Here is the answer." 
def test_multiple_output_rails_first_blocks():
    """A "Yes" from the first output check ends the turn with an output-rail exception."""
    completions = [
        " ask question",
        ' "Some bad output"',
        "Yes",  # check_inappropriate verdict: block
    ]
    chat = TestChat(multi_output_config, llm_completions=completions)

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "OutputRailException"


def test_multiple_output_rails_second_blocks():
    """The first output check answers "No"; the second answers "Yes" and raises the exception."""
    completions = [
        " ask question",
        ' "Response with leaked data"',
        "No",  # check_inappropriate verdict: pass
        "Yes",  # check_data_leakage verdict: block
    ]
    chat = TestChat(multi_output_config, llm_completions=completions)

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "OutputRailException"


# --- Default task (backward compatibility) ---

# Both rails are listed without `$input_task` / `$output_task`; prompts exist only
# for the `self_check_input` and `self_check_output` tasks.
default_task_config = RailsConfig.from_content(
    """
    define user ask question
      "tell me something"

    define flow
      user ask question
      bot respond
""",
    yaml_content="""
    models: []
    rails:
      input:
        flows:
          - self check input
      output:
        flows:
          - self check output
    prompts:
      - task: self_check_input
        content: ...
      - task: self_check_output
        content: ...

    enable_rails_exceptions: True
    """,
)


def test_default_task_input_still_works():
    """A bare `self check input` rail (no `$input_task`) still blocks on a "Yes" completion."""
    chat = TestChat(
        default_task_config,
        llm_completions=["Yes"],  # input check verdict: block
    )

    user_turn = {"role": "user", "content": "bad input"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "InputRailException"


def test_default_task_output_still_works():
    """A bare `self check output` rail (no `$output_task`) still blocks on a "Yes" completion."""
    completions = [
        "No",  # input check verdict: pass
        " ask question",
        ' "Something that should be blocked"',
        "Yes",  # output check verdict: block
    ]
    chat = TestChat(default_task_config, llm_completions=completions)

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "OutputRailException"


# --- Per-task LLM configuration ---

# Two input rails, each bound to its own prompt task through `$input_task`.
per_task_input_config = RailsConfig.from_content(
    """
    define user express greeting
      "hello"
      "hi"

    define bot express greeting
      "Hey!"

    define flow greeting
      user express greeting
      bot express greeting
""",
    yaml_content="""
    models: []
    rails:
      input:
        flows:
          - self check input $input_task=check_harmful
          - self check input $input_task=check_off_topic
    prompts:
      - task: check_harmful
        content: |
          Is this message harmful?
          User message: "{{ user_input }}"
          Answer (Yes or No):
      - task: check_off_topic
        content: |
          Is this message off-topic?
          User message: "{{ user_input }}"
          Answer (Yes or No):

    enable_rails_exceptions: True
    """,
)


def test_per_task_llm_input_uses_task_specific_model():
    """Models registered under each input task name are each invoked exactly once."""
    harmful_model = FakeLLMModel(responses=["No"])
    off_topic_model = FakeLLMModel(responses=["No"])

    chat = TestChat(
        per_task_input_config,
        llm_completions=[" express greeting", ' "Hey!"'],
    )

    # Register the task-specific models on the runtime's `llms` action parameter.
    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["check_harmful"] = harmful_model
    llm_registry["check_off_topic"] = off_topic_model

    user_turn = {"role": "user", "content": "hello"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    assert harmful_model.inference_count == 1
    assert off_topic_model.inference_count == 1


def test_per_task_llm_input_first_blocks_skips_second():
    """When the `check_harmful` model answers "Yes", the `check_off_topic` model is never called."""
    harmful_model = FakeLLMModel(responses=["Yes"])
    off_topic_model = FakeLLMModel(responses=["No"])

    # The main test LLM gets no completions in this scenario.
    chat = TestChat(per_task_input_config, llm_completions=[])

    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["check_harmful"] = harmful_model
    llm_registry["check_off_topic"] = off_topic_model

    user_turn = {"role": "user", "content": "bad message"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "InputRailException"
    assert harmful_model.inference_count == 1
    assert off_topic_model.inference_count == 0


# Two output rails, each bound to its own prompt task through `$output_task`.
per_task_output_config = RailsConfig.from_content(
    """
    define user ask question
      "tell me something"

    define flow
      user ask question
      bot respond
""",
    yaml_content="""
    models: []
    rails:
      output:
        flows:
          - self check output $output_task=check_inappropriate
          - self check output $output_task=check_data_leakage
    prompts:
      - task: check_inappropriate
        content: |
          Is this response inappropriate?
          Bot response: "{{ bot_response }}"
          Answer (Yes or No):
      - task: check_data_leakage
        content: |
          Does this response leak sensitive data?
          Bot response: "{{ bot_response }}"
          Answer (Yes or No):

    enable_rails_exceptions: True
    """,
)


def test_per_task_llm_output_uses_task_specific_model():
    """Models registered under each output task name are each invoked exactly once."""
    inappropriate_model = FakeLLMModel(responses=["No"])
    leakage_model = FakeLLMModel(responses=["No"])

    chat = TestChat(
        per_task_output_config,
        llm_completions=[" ask question", "  Here is the answer."],
    )

    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["check_inappropriate"] = inappropriate_model
    llm_registry["check_data_leakage"] = leakage_model

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    assert reply["content"] == "Here is the answer."
    assert inappropriate_model.inference_count == 1
    assert leakage_model.inference_count == 1


def test_per_task_llm_output_first_blocks_skips_second():
    """When the `check_inappropriate` model answers "Yes", the `check_data_leakage` model is never called."""
    inappropriate_model = FakeLLMModel(responses=["Yes"])
    leakage_model = FakeLLMModel(responses=["No"])

    chat = TestChat(
        per_task_output_config,
        llm_completions=[" ask question", ' "Some bad output"'],
    )

    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["check_inappropriate"] = inappropriate_model
    llm_registry["check_data_leakage"] = leakage_model

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "exception"
    assert reply["content"]["type"] == "OutputRailException"
    assert inappropriate_model.inference_count == 1
    assert leakage_model.inference_count == 0


def test_per_task_llm_falls_back_to_main_when_not_configured():
    """With nothing registered for the tasks, the main test LLM serves all three completions."""
    completions = [
        "No",  # check_harmful via main LLM
        "No",  # check_off_topic via main LLM
        " express greeting",
    ]
    chat = TestChat(per_task_input_config, llm_completions=completions)

    user_turn = {"role": "user", "content": "hello"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    assert chat.llm.inference_count == 3


# --- Model fallback chain ---


def test_input_fallback_to_default_task_model():
    """Custom input tasks without their own model use the one registered as `self_check_input`."""
    shared_check_model = FakeLLMModel(responses=["No", "No"])

    chat = TestChat(per_task_input_config, llm_completions=[" express greeting"])

    # Only the default input task name is registered; both custom tasks must share it.
    chat.app.runtime.registered_action_params["llms"]["self_check_input"] = shared_check_model

    user_turn = {"role": "user", "content": "hello"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    assert shared_check_model.inference_count == 2
    assert chat.llm.inference_count == 1


def test_output_fallback_to_default_task_model():
    """Custom output tasks without their own model use the one registered as `self_check_output`."""
    shared_check_model = FakeLLMModel(responses=["No", "No"])

    chat = TestChat(
        per_task_output_config,
        llm_completions=[" ask question", "  Here is the answer."],
    )

    # Only the default output task name is registered; both custom tasks must share it.
    chat.app.runtime.registered_action_params["llms"]["self_check_output"] = shared_check_model

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    assert reply["content"] == "Here is the answer."
    assert shared_check_model.inference_count == 2
    assert chat.llm.inference_count == 2


def test_input_fallback_chain_prefers_task_over_default():
    """A model registered for `check_harmful` wins over the `self_check_input` model for that task."""
    task_model = FakeLLMModel(responses=["No"])
    default_task_model = FakeLLMModel(responses=["No"])

    chat = TestChat(per_task_input_config, llm_completions=[" express greeting"])

    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["self_check_input"] = default_task_model
    llm_registry["check_harmful"] = task_model

    user_turn = {"role": "user", "content": "hello"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    # One call each: check_harmful hits its own model, check_off_topic hits the default one.
    assert task_model.inference_count == 1
    assert default_task_model.inference_count == 1
    assert chat.llm.inference_count == 1


def test_output_fallback_chain_prefers_task_over_default():
    """A model registered for `check_inappropriate` wins over the `self_check_output` model for that task."""
    task_model = FakeLLMModel(responses=["No"])
    default_task_model = FakeLLMModel(responses=["No"])

    chat = TestChat(
        per_task_output_config,
        llm_completions=[" ask question", "  Here is the answer."],
    )

    llm_registry = chat.app.runtime.registered_action_params["llms"]
    llm_registry["self_check_output"] = default_task_model
    llm_registry["check_inappropriate"] = task_model

    user_turn = {"role": "user", "content": "tell me something"}
    reply = chat.app.generate(messages=[user_turn])

    assert reply["role"] == "assistant"
    # One call each: check_inappropriate hits its own model, check_data_leakage hits the default one.
    assert task_model.inference_count == 1
    assert default_task_model.inference_count == 1
    assert chat.llm.inference_count == 2


def test_input_no_model_raises_error():
    """`_get_llm` raises ValueError when the llms mapping is empty and `default_llm` is None."""
    from nemoguardrails.library.self_check.message_check.actions import _get_llm

    with pytest.raises(ValueError, match="No matching model"):
        _get_llm({}, "check_harmful", "self_check_input", default_llm=None)


def test_get_llm_fallback_chain():
    """`_get_llm` tries the task entry, then the default-task entry, then `default_llm`, else raises."""
    from nemoguardrails.library.self_check.message_check.actions import _get_llm

    task_model = FakeLLMModel(responses=[])
    default_task_model = FakeLLMModel(responses=[])
    main_model = FakeLLMModel(responses=[])

    registered = {
        "check_harmful": task_model,
        "self_check_input": default_task_model,
    }

    # Each row: (llms mapping, requested task, model expected back).
    resolution_steps = [
        (registered, "check_harmful", task_model),  # level 1: exact task match
        (registered, "check_off_topic", default_task_model),  # level 2: default task entry
        ({}, "check_harmful", main_model),  # level 3: the default_llm argument
    ]
    for llms, task, expected in resolution_steps:
        assert _get_llm(llms, task, "self_check_input", default_llm=main_model) is expected

    # Level 4: nothing registered and no default_llm -> ValueError, for input and output defaults alike.
    unresolved = (
        ("check_harmful", "self_check_input"),
        ("check_inappropriate", "self_check_output"),
    )
    for task, default_task in unresolved:
        with pytest.raises(ValueError, match="No matching model"):
            _get_llm({}, task, default_task, default_llm=None)