diff --git a/examples/configs/nemotron/config.yml b/examples/configs/nemotron/config.yml
index 4bc306108..8e365b25c 100644
--- a/examples/configs/nemotron/config.yml
+++ b/examples/configs/nemotron/config.yml
@@ -2,5 +2,3 @@ models:
- type: main
engine: nim
model: nvidia/llama-3.1-nemotron-ultra-253b-v1
- reasoning_config:
- remove_reasoning_traces: False # Set True to remove traces from the internal tasks
diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index a230e5ce3..cb2add1f6 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -57,7 +57,7 @@
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.prompts import get_prompt
-from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
+from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
@@ -496,7 +496,6 @@ async def generate_user_intent(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_USER_INTENT, output=result
)
- result = result.text
user_intent = get_first_nonempty_line(result)
if user_intent is None:
@@ -594,10 +593,6 @@ async def generate_user_intent(
Task.GENERAL, output=text
)
- text = _process_parsed_output(
- text, self._include_reasoning_traces()
- )
-
else:
# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -639,8 +634,6 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
-
- text = _process_parsed_output(text, self._include_reasoning_traces())
text = text.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -750,7 +743,6 @@ async def generate_next_step(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_NEXT_STEPS, output=result
)
- result = result.text
# If we don't have multi-step generation enabled, we only look at the first line.
if not self.config.enable_multi_step_generation:
@@ -1036,10 +1028,6 @@ async def generate_bot_message(
Task.GENERAL, output=result
)
- result = _process_parsed_output(
- result, self._include_reasoning_traces()
- )
-
log.info(
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
time() - t0,
@@ -1111,10 +1099,6 @@ async def generate_bot_message(
Task.GENERATE_BOT_MESSAGE, output=result
)
- result = _process_parsed_output(
- result, self._include_reasoning_traces()
- )
-
# TODO: catch openai.error.InvalidRequestError from exceeding max token length
result = get_multiline_response(result)
@@ -1212,7 +1196,6 @@ async def generate_value(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_VALUE, output=result
)
- result = result.text
# We only use the first line for now
# TODO: support multi-line values?
@@ -1433,7 +1416,6 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
)
- result = result.text
# TODO: Implement logic for generating more complex Colang next steps (multi-step),
# not just a single bot intent.
@@ -1516,7 +1498,6 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
- result = _process_parsed_output(result, self._include_reasoning_traces())
text = result.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -1529,10 +1510,6 @@ async def generate_intent_steps_message(
events=[new_event_dict("BotMessage", text=text)],
)
- def _include_reasoning_traces(self) -> bool:
- """Get the configuration value for whether to include reasoning traces in output."""
- return _get_apply_to_reasoning_traces(self.config)
-
def clean_utterance_content(utterance: str) -> str:
"""
@@ -1550,27 +1527,3 @@ def clean_utterance_content(utterance: str) -> str:
# It should be translated to an actual \n character.
utterance = utterance.replace("\\n", "\n")
return utterance
-
-
-def _record_reasoning_trace(trace: str) -> None:
- """Store the reasoning trace in context for later retrieval."""
- reasoning_trace_var.set(trace)
-
-
-def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
- """Combine trace and text if requested, otherwise just return text."""
- return (trace + text) if (trace and include_reasoning) else text
-
-
-def _process_parsed_output(
- output: ParsedTaskOutput, include_reasoning_trace: bool
-) -> str:
- """Record trace, then assemble the final LLM response."""
- if reasoning_trace := output.reasoning_trace:
- _record_reasoning_trace(reasoning_trace)
- return _assemble_response(output.text, reasoning_trace, include_reasoning_trace)
-
-
-def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
- """Get the configuration value for whether to include reasoning traces in output."""
- return config.rails.output.apply_to_reasoning_traces
diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index c36899bb8..b71adb76c 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -179,7 +179,10 @@ def _store_tool_calls(response) -> None:
def _store_response_metadata(response) -> None:
- """Store response metadata excluding content for metadata preservation."""
+ """Store response metadata excluding content for metadata preservation.
+
+ Also extracts reasoning content from additional_kwargs when the LangChain response provides it.
+ """
if hasattr(response, "model_fields"):
metadata = {}
for field_name in response.model_fields:
@@ -188,6 +191,18 @@ def _store_response_metadata(response) -> None:
): # Exclude content since it may be modified by rails
metadata[field_name] = getattr(response, field_name)
llm_response_metadata_var.set(metadata)
+
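+ # Some LangChain chat models surface the provider's reasoning output under
+ # additional_kwargs["reasoning_content"]; when present, capture it as the reasoning trace.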
+ if hasattr(response, "additional_kwargs"):
+ additional_kwargs = response.additional_kwargs
+ if (
+ isinstance(additional_kwargs, dict)
+ and "reasoning_content" in additional_kwargs
+ ):
+ reasoning_content = additional_kwargs["reasoning_content"]
+ if reasoning_content:
+ reasoning_trace_var.set(reasoning_content)
else:
llm_response_metadata_var.set(None)
diff --git a/nemoguardrails/actions/v2_x/generation.py b/nemoguardrails/actions/v2_x/generation.py
index 72e703a2c..fbc20f0c7 100644
--- a/nemoguardrails/actions/v2_x/generation.py
+++ b/nemoguardrails/actions/v2_x/generation.py
@@ -318,8 +318,6 @@ async def generate_user_intent( # pyright: ignore (TODO - Signature completely
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
)
- result = result.text
-
user_intent = get_first_nonempty_line(result)
# GPT-4o often adds 'user intent: ' in front
if user_intent and ":" in user_intent:
@@ -401,8 +399,6 @@ async def generate_user_intent_and_bot_action(
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
)
- result = result.text
-
user_intent = get_first_nonempty_line(result)
if user_intent and ":" in user_intent:
@@ -578,8 +574,6 @@ async def generate_flow_from_instructions(
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
)
- result = result.text
-
# TODO: why this is not part of a filter or output_parser?
#
lines = _remove_leading_empty_lines(result).split("\n")
@@ -660,8 +654,6 @@ async def generate_flow_from_name(
task=Task.GENERATE_FLOW_FROM_NAME, output=result
)
- result = result.text
-
lines = _remove_leading_empty_lines(result).split("\n")
if lines[0].startswith("flow"):
@@ -736,8 +728,6 @@ async def generate_flow_continuation(
task=Task.GENERATE_FLOW_CONTINUATION, output=result
)
- result = result.text
-
lines = _remove_leading_empty_lines(result).split("\n")
if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -869,8 +859,6 @@ async def generate_value( # pyright: ignore (TODO - different arguments to base
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
)
- result = result.text
-
# We only use the first line for now
# TODO: support multi-line values?
value = result.strip().split("\n")[0]
@@ -994,8 +982,6 @@ async def generate_flow(
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
)
- result = result.text
-
result = _remove_leading_empty_lines(result)
lines = result.split("\n")
if "codeblock" in lines[0]:
diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py
index cfda91cf1..2407210fa 100644
--- a/nemoguardrails/library/content_safety/actions.py
+++ b/nemoguardrails/library/content_safety/actions.py
@@ -83,7 +83,6 @@ async def content_safety_check_input(
)
result = llm_task_manager.parse_task_output(task, output=result)
- result = result.text
is_safe, *violated_policies = result
@@ -165,7 +164,6 @@ async def content_safety_check_output(
result = llm_task_manager.parse_task_output(task, output=result)
- result = result.text
is_safe, *violated_policies = result
return {"allowed": is_safe, "policy_violations": violated_policies}
diff --git a/nemoguardrails/library/self_check/facts/actions.py b/nemoguardrails/library/self_check/facts/actions.py
index c9cc6900b..ce5598864 100644
--- a/nemoguardrails/library/self_check/facts/actions.py
+++ b/nemoguardrails/library/self_check/facts/actions.py
@@ -85,7 +85,6 @@ async def self_check_facts(
task, output=response, forced_output_parser="is_content_safe"
)
- result = result.text
is_not_safe = result[0]
result = float(not is_not_safe)
diff --git a/nemoguardrails/library/self_check/input_check/actions.py b/nemoguardrails/library/self_check/input_check/actions.py
index fa6c6cc1f..8255258b2 100644
--- a/nemoguardrails/library/self_check/input_check/actions.py
+++ b/nemoguardrails/library/self_check/input_check/actions.py
@@ -87,7 +87,6 @@ async def self_check_input(
task, output=response, forced_output_parser="is_content_safe"
)
- result = result.text
is_safe = result[0]
if not is_safe:
diff --git a/nemoguardrails/library/self_check/output_check/actions.py b/nemoguardrails/library/self_check/output_check/actions.py
index d05b2659d..10d3ba340 100644
--- a/nemoguardrails/library/self_check/output_check/actions.py
+++ b/nemoguardrails/library/self_check/output_check/actions.py
@@ -91,7 +91,6 @@ async def self_check_output(
task, output=response, forced_output_parser="is_content_safe"
)
- result = result.text
is_safe = result[0]
return is_safe
diff --git a/nemoguardrails/llm/filters.py b/nemoguardrails/llm/filters.py
index a0d80bb5d..613a2399b 100644
--- a/nemoguardrails/llm/filters.py
+++ b/nemoguardrails/llm/filters.py
@@ -24,16 +24,6 @@
)
-@dataclass
-class ReasoningExtractionResult:
- """
- Holds cleaned response text and optional chain-of-thought reasoning trace extracted from LLM output.
- """
-
- text: str
- reasoning_trace: Optional[str] = None
-
-
def colang(events: List[dict]) -> str:
"""Filter that turns an array of events into a colang history."""
return get_colang_history(events)
@@ -448,100 +438,3 @@ def conversation_to_events(conversation: List) -> List[dict]:
)
return events
-
-
-def _find_token_positions_for_removal(
- response: str, start_token: Optional[str], end_token: Optional[str]
-) -> Tuple[int, int]:
- """Helper function to find token positions specifically for text removal.
-
- This is useful, for example, to remove reasoning traces from a reasoning LLM response.
-
- This is optimized for the removal use case:
- 1. Uses find() for first start token
- 2. Uses rfind() for last end token
- 3. Sets start_index to 0 if start token is missing
-
- Args:
- response(str): The text to search in
- start_token(str): The token marking the start of text to remove
- end_token(str): The token marking the end of text to remove
-
- Returns:
- A tuple of (start_index, end_index) marking the span to remove;
- both indices are -1 if start_token and end_token are not provided.
- """
- if not start_token or not end_token:
- return -1, -1
-
- start_index = response.find(start_token)
- # if the start index is missing, this is probably a continuation of a bot message
- # started in the prompt.
- if start_index == -1:
- start_index = 0
-
- end_index = response.rfind(end_token)
-
- return start_index, end_index
-
-
-def find_reasoning_tokens_position(
- response: str, start_token: Optional[str], end_token: Optional[str]
-) -> Tuple[int, int]:
- """Finds the positions of the first start token and the last end token.
-
- This is intended to find the outermost boundaries of potential
- reasoning sections, typically for removal.
-
- Args:
- response(str): The text to search in.
- start_token(Optional[str]): The token marking the start of reasoning.
- end_token(Optional[str]): The token marking the end of reasoning.
-
- Returns:
- A tuple (start_index, end_index).
- - start_index: Position of the first `start_token`, or 0 if not found.
- - end_index: Position of the last `end_token`, or -1 if not found.
- """
-
- return _find_token_positions_for_removal(response, start_token, end_token)
-
-
-def extract_and_strip_trace(
- response: str, start_token: str, end_token: str
-) -> ReasoningExtractionResult:
- """Extracts and removes reasoning traces from the given text.
-
- This function identifies reasoning traces in the text that are marked
- by specific start and end tokens. It extracts these traces, removes
- them from the original text, and returns both the cleaned text and
- the extracted reasoning trace.
-
- Args:
- response (str): The text to process.
- start_token (str): The token marking the start of a reasoning trace.
- end_token (str): The token marking the end of a reasoning trace.
-
- Returns:
- ReasoningExtractionResult: An object containing the cleaned text
- without reasoning traces and the extracted reasoning trace, if any.
- """
-
- start_index, end_index = find_reasoning_tokens_position(
- response, start_token, end_token
- )
- # handles invalid/empty tokens returned as (-1, -1)
- if start_index == -1 and end_index == -1:
- return ReasoningExtractionResult(text=response, reasoning_trace=None)
- # end token is missing
- if end_index == -1:
- return ReasoningExtractionResult(text=response, reasoning_trace=None)
- # extrace if tokens are present and start < end
- if start_index < end_index:
- reasoning_trace = response[start_index : end_index + len(end_token)]
- cleaned_text = response[:start_index] + response[end_index + len(end_token) :]
- return ReasoningExtractionResult(
- text=cleaned_text, reasoning_trace=reasoning_trace
- )
-
- return ReasoningExtractionResult(text=response, reasoning_trace=None)
diff --git a/nemoguardrails/llm/taskmanager.py b/nemoguardrails/llm/taskmanager.py
index 3651676db..bf3598d62 100644
--- a/nemoguardrails/llm/taskmanager.py
+++ b/nemoguardrails/llm/taskmanager.py
@@ -23,12 +23,10 @@
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
-from nemoguardrails.actions.llm.utils import get_and_clear_reasoning_trace_contextvar
from nemoguardrails.llm.filters import (
co_v2,
colang,
colang_without_identifiers,
- extract_and_strip_trace,
first_turns,
indent,
last_turns,
@@ -55,56 +53,6 @@
from nemoguardrails.rails.llm.config import MessageTemplate, RailsConfig
-def output_has_reasoning_traces(output: str, start_token: str, end_token: str) -> bool:
- """Checks if the output string contains both start and end reasoning tokens."""
- return start_token in output and end_token in output
-
-
-@dataclass
-class ParsedTaskOutput:
- """
- Encapsulates the result of running and parsing an LLM task.
-
- Attributes:
- text (str): The cleaned and parsed output string, representing
- the main result of the task.
- reasoning_trace (Optional[str]): An optional chain-of-thought
- reasoning trace, providing insights into the reasoning
- process behind the task output, if available.
- """
-
- text: str
- reasoning_trace: Optional[str] = None
-
-
-def should_remove_reasoning_traces_from_output(config, task):
- model = get_task_model(config, task)
-
- model_config = (
- model
- and model.reasoning_config
- and model.reasoning_config.remove_reasoning_traces
- )
-
- if config.rails.output.apply_to_reasoning_traces:
- return False
- else:
- return model_config
-
-
-def get_reasoning_token_tags(config, task):
- model = get_task_model(config, task)
-
- if model and model.reasoning_config:
- start_token = model.reasoning_config.start_token
- end_token = model.reasoning_config.end_token
- else:
- start_token = None
- end_token = None
-
- return start_token, end_token
-
-
class LLMTaskManager:
"""Interface for interacting with an LLM in a task-oriented way."""
@@ -156,70 +104,6 @@ def _get_general_instructions(self):
return text
- def _preprocess_events_for_prompt(
- self, events: Optional[List[dict]]
- ) -> Optional[List[dict]]:
- """Remove reasoning traces from bot messages before rendering them in prompts.
-
- This prevents reasoning traces from being included in LLM prompt history when
- rails.output.apply_to_reasoning_traces=true is enabled.
-
- Args:
- events: The list of events to preprocess
-
- Returns:
- A new list of preprocessed events, or None if events was None
- """
- if not events:
- return None
-
- processed_events = copy.deepcopy(events)
-
- for event in processed_events:
- if (
- isinstance(event, dict)
- and event.get("type") == "BotMessage"
- and "text" in event
- ):
- bot_utterance = event["text"]
- for task in Task:
- start_token, end_token = get_reasoning_token_tags(self.config, task)
- if (
- start_token
- and end_token
- and output_has_reasoning_traces(
- bot_utterance, start_token, end_token
- )
- ):
- result = extract_and_strip_trace(
- bot_utterance, start_token, end_token
- )
- event["text"] = result.text
- break
-
- elif (
- isinstance(event, dict)
- and event.get("type") == "StartUtteranceBotAction"
- and "script" in event
- ):
- bot_utterance = event["script"]
- for task in Task:
- start_token, end_token = get_reasoning_token_tags(self.config, task)
- if (
- start_token
- and end_token
- and output_has_reasoning_traces(
- bot_utterance, start_token, end_token
- )
- ):
- result = extract_and_strip_trace(
- bot_utterance, start_token, end_token
- )
- event["script"] = result.text
- break
-
- return processed_events
-
def _render_string(
self,
template_str: str,
@@ -234,9 +118,6 @@ def _render_string(
:return: The rendered template.
:rtype: str.
"""
- # Preprocess events to remove reasoning traces from BotMessage events
- processed_events = self._preprocess_events_for_prompt(events)
-
template = self.env.from_string(template_str)
# First, we extract all the variables from the template.
@@ -244,7 +125,7 @@ def _render_string(
# This is the context that will be passed to the template when rendering.
render_context = {
- "history": processed_events,
+ "history": events,
"general_instructions": self._get_general_instructions(),
"sample_conversation": self.config.sample_conversation,
"sample_conversation_two_turns": self.config.sample_conversation,
@@ -426,8 +307,8 @@ def render_task_prompt(
def parse_task_output(
self, task: Task, output: str, forced_output_parser: Optional[str] = None
- ) -> ParsedTaskOutput:
- """Parses the output of a task, optionally extracting reasoning traces.
+ ) -> str:
+ """Parses the output of a task using the configured output parser.
Args:
task (Task): The task for which the output is being parsed.
@@ -435,30 +316,8 @@ def parse_task_output(
forced_output_parser (Optional[str]): An optional parser name to force
Returns:
- ParsedTaskOutput: An object containing the parsed text (which may
- include or exclude reasoning traces based on configuration) and
- any reasoning trace.
+ str: The parsed text output.
"""
- reasoning_trace: Optional[str] = None
-
- # Get the tokens first to check for their presence
- start_token, end_token = get_reasoning_token_tags(self.config, task)
-
- # 1. strip and capture reasoning traces if configured and present
- if (
- start_token
- and end_token
- and output_has_reasoning_traces(output, start_token, end_token)
- ):
- reasoning_trace_result = extract_and_strip_trace(
- output, start_token, end_token
- )
- reasoning_trace = reasoning_trace_result.reasoning_trace
-
- if should_remove_reasoning_traces_from_output(self.config, task):
- output = reasoning_trace_result.text
-
- # 2. delegate to existing parser
prompt = get_prompt(self.config, task)
parser_name = forced_output_parser or prompt.output_parser
parser_fn = self.output_parsers.get(parser_name)
@@ -469,7 +328,7 @@ def parse_task_output(
logging.info("No output parser found for %s", prompt.output_parser)
parsed_text = output
- return ParsedTaskOutput(text=parsed_text, reasoning_trace=reasoning_trace)
+ return parsed_text
def has_output_parser(self, task: Task):
prompt = get_prompt(self.config, task)
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 749ecfd32..eafa6902c 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -69,34 +69,6 @@
colang_path_dirs.append(guardrails_stdlib_path)
-class ReasoningModelConfig(BaseModel):
- """Configuration for reasoning models/LLMs, including start and end tokens for reasoning traces."""
-
- remove_reasoning_traces: Optional[bool] = Field(
- default=True,
- description="For reasoning models (e.g. DeepSeek-r1), if the output parser should remove reasoning traces.",
- )
- remove_thinking_traces: Optional[bool] = Field(
- default=None,
- deprecated="The `remove_thinking_traces` field is deprecated use remove_reasoning_traces instead.",
- )
- start_token: Optional[str] = Field(
- default="",
- description="The start token used for reasoning traces.",
- )
- end_token: Optional[str] = Field(
- default="",
- description="The end token used for reasoning traces.",
- )
-
- @model_validator(mode="after")
- def _migrate_thinking_traces(self) -> "ReasoningModelConfig":
- # If someone uses the old field, propagate it silently
- if self.remove_thinking_traces is not None:
- self.remove_reasoning_traces = self.remove_thinking_traces
- return self
-
-
class Model(BaseModel):
"""Configuration of a model used by the rails engine.
@@ -118,10 +90,6 @@ class Model(BaseModel):
default=None,
description='Optional environment variable with model\'s API Key. Do not include "$".',
)
- reasoning_config: Optional[ReasoningModelConfig] = Field(
- default_factory=ReasoningModelConfig,
- description="Configuration parameters for reasoning LLMs.",
- )
parameters: Dict[str, Any] = Field(default_factory=dict)
mode: Literal["chat", "text"] = Field(
@@ -492,15 +460,6 @@ class OutputRails(BaseModel):
description="Configuration for streaming output rails.",
)
- apply_to_reasoning_traces: bool = Field(
- default=False,
- description=(
- "If True, output rails will apply guardrails to both reasoning traces and output response. "
- "If False, output rails will only apply guardrails to the output response excluding the reasoning traces, "
- "thus keeping reasoning traces unaltered."
- ),
- )
-
class RetrievalRails(BaseModel):
"""Configuration of retrieval rails."""
@@ -1438,80 +1397,6 @@ class RailsConfig(BaseModel):
description="Configuration for tracing.",
)
- @root_validator(pre=True, allow_reuse=True)
- def check_reasoning_traces_with_dialog_rails(cls, values):
- """Check that reasoning traces are not enabled when dialog rails are present."""
-
- models = values.get("models", [])
- rails = values.get("rails", {})
- dialog_rails = rails.get("dialog", {})
-
- # dialog rail tasks that should not have reasoning traces
- dialog_rail_tasks = [
- # Task.GENERATE_BOT_MESSAGE,
- Task.GENERATE_USER_INTENT,
- Task.GENERATE_NEXT_STEPS,
- Task.GENERATE_INTENT_STEPS_MESSAGE,
- ]
-
- embeddings_only = dialog_rails.get("user_messages", {}).get(
- "embeddings_only", False
- )
-
- has_dialog_rail_configs = (
- bool(values.get("user_messages"))
- or bool(values.get("bot_messages"))
- or bool(values.get("flows"))
- )
-
- # dialog rails are activated (explicitly or implicitly) and require validation
- # skip validation when embeddings_only is True
- has_dialog_rails = (
- bool(dialog_rails) or has_dialog_rail_configs
- ) and not embeddings_only
-
- if has_dialog_rails:
- main_model = next(
- (model for model in models if model.get("type") == "main"), None
- )
-
- violations = []
-
- for task in dialog_rail_tasks:
- task_model = next(
- (model for model in models if model.get("type") == task.value), None
- )
-
- if task_model:
- reasoning_config = (
- task_model.reasoning_config
- if hasattr(task_model, "reasoning_config")
- else task_model.get("reasoning_config", {})
- )
- if not reasoning_config.get("remove_reasoning_traces", True):
- violations.append(
- f"Model '{task_model.get('type')}' has reasoning traces enabled in config.yml. "
- f"Reasoning traces must be disabled for dialog rail tasks. "
- f"Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config for this model."
- )
- elif main_model:
- reasoning_config = (
- main_model.reasoning_config
- if hasattr(main_model, "reasoning_config")
- else main_model.get("reasoning_config", {})
- )
- if not reasoning_config.get("remove_reasoning_traces", True):
- violations.append(
- f"Main model has reasoning traces enabled in config.yml and is being used for dialog rail task '{task.value}'. "
- f"Reasoning traces must be disabled when dialog rails are present. "
- f"Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config for the main model."
- )
-
- if violations:
- raise ValueError("\n".join(violations))
-
- return values
-
@root_validator(pre=True, allow_reuse=True)
def check_prompt_exist_for_self_check_rails(cls, values):
rails = values.get("rails", {})
diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py
index 36cf9b303..adcbd37e9 100644
--- a/tests/test_config_validation.py
+++ b/tests/test_config_validation.py
@@ -122,518 +122,3 @@ def test_passthrough_and_single_call_incompatibility():
# LLMRails(config=config)
#
# assert "You must provide a `self_check_facts` prompt" in str(exc_info.value)
-
-
-def test_reasoning_traces_with_explicit_dialog_rails():
- """Test that reasoning traces cannot be enabled when dialog rails are explicitly configured."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
- assert "Main model has reasoning traces enabled in config.yml" in str(
- exc_info.value
- )
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_without_dialog_rails():
- """Test that reasoning traces can be enabled when no dialog rails are present."""
-
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- """,
- )
-
-
-def test_dialog_rails_without_reasoning_traces():
- """Test that dialog rails can be enabled when reasoning traces are not enabled."""
-
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
-
-def test_input_rails_only_no_dialog_rails():
- config = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- rails:
- input:
- flows:
- - self check input
- prompts:
- - task: self_check_input
- content: "Check if input is safe"
- """,
- )
-
- assert not config.user_messages
- assert not config.bot_messages
- assert not config.flows
- assert not config.rails.dialog.single_call.enabled
-
-
-def test_no_dialog_tasks_with_only_output_rails():
- """Test that dialog tasks are not used when only output rails are present."""
-
- config = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- rails:
- output:
- flows:
- - self check output
- prompts:
- - task: self_check_output
- content: "Check if output is safe"
- """,
- )
-
- assert not config.user_messages
- assert not config.bot_messages
- assert not config.flows
- assert not config.rails.dialog.single_call.enabled
-
-
-def test_reasoning_traces_with_implicit_dialog_rails_user_bot_messages():
- """Test that reasoning traces cannot be enabled when dialog rails are implicitly enabled thru user/bot messages."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- """,
- colang_content="""
- define user express greeting
- "hello"
- "hi"
-
- define bot express greeting
- "Hello there!"
-
- define flow
- user express greeting
- bot express greeting
- """,
- )
-
- assert "Main model has reasoning traces enabled in config.yml" in str(
- exc_info.value
- )
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_implicit_dialog_rails_flows_only():
- """Test that reasoning traces cannot be enabled when dialog rails are implicitly enabled thru flows only."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: False
- """,
- colang_content="""
- define flow
- user express greeting
- bot express greeting
- define user express greeting
- "hi"
- define bot express greeting
- "HI HI"
- """,
- )
-
- assert "Main model has reasoning traces enabled in config.yml" in str(
- exc_info.value
- )
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_implicit_dialog_rails_user_messages_only():
- """Test that reasoning traces cannot be enabled when dialog rails are implicitly enabled through user messages (user canonical forms) only."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- """,
- colang_content="""
- define user express greeting
- "hello"
- "hi"
- """,
- )
-
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
-
-
-def test_reasoning_traces_with_bot_messages_only():
- """Test that reasoning traces cannot be enabled when bot messages are present."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: False
- """,
- colang_content="""
- define bot express greeting
- "Hello there!"
- """,
- )
-
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
-
-
-def test_reasoning_traces_with_dedicated_task_models():
- """Test that reasoning traces cannot be enabled for dedicated task models when dialog rails are present."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_bot_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_user_intent
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
- assert (
- "Model 'generate_user_intent' has reasoning traces enabled in config.yml"
- in str(exc_info.value)
- )
- assert "Reasoning traces must be disabled for dialog rail tasks" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_mixed_task_models():
- """Test that reasoning traces cannot be enabled for any task model when dialog rails are present."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_bot_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_user_intent
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
- assert (
- "Model 'generate_user_intent' has reasoning traces enabled in config.yml"
- in str(exc_info.value)
- )
- assert "Reasoning traces must be disabled for dialog rail tasks" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_all_dialog_tasks():
- """Test that reasoning traces cannot be enabled for any dialog rail task."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_bot_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_user_intent
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_next_steps
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_intent_steps_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
- error_message = str(exc_info.value)
- assert (
- "Model 'generate_bot_message' has reasoning traces enabled in config.yml"
- not in error_message
- )
- assert (
- "Model 'generate_next_steps' has reasoning traces enabled in config.yml"
- in error_message
- )
- assert "Reasoning traces must be disabled for dialog rail tasks" in error_message
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in error_message
- )
-
-
-def test_reasoning_traces_with_dedicated_models_no_dialog_rails():
- """Test that reasoning traces can be enabled for dedicated models when no dialog rails are present."""
-
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_bot_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_user_intent
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- """,
- )
-
-
-def test_reasoning_traces_with_implicit_dialog_rails_and_dedicated_models():
- """Test that reasoning traces cannot be enabled for dedicated models when dialog rails are implicitly enabled."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- - type: generate_user_intent
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- """,
- colang_content="""
- define user express greeting
- "hello"
- "hi"
-
- define bot express greeting
- "Hello there!"
-
- define flow
- user express greeting
- bot express greeting
- """,
- )
-
- assert (
- "Model 'generate_user_intent' has reasoning traces enabled in config.yml"
- in str(exc_info.value)
- )
- assert "Reasoning traces must be disabled for dialog rail tasks" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_partial_dedicated_models():
- """Test that reasoning traces cannot be enabled for any model when some tasks use dedicated models and others fall back to main."""
-
- with pytest.raises(ValueError) as exc_info:
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: false
- - type: generate_bot_message
- engine: openai
- model: gpt-3.5-turbo-instruct
- rails:
- dialog:
- single_call:
- enabled: true
- """,
- )
-
- assert "Main model has reasoning traces enabled in config.yml" in str(
- exc_info.value
- )
- assert "Reasoning traces must be disabled when dialog rails are present" in str(
- exc_info.value
- )
- assert (
- "Please update your config.yml to set 'remove_reasoning_traces: true' under reasoning_config"
- in str(exc_info.value)
- )
-
-
-def test_reasoning_traces_with_implicit_dialog_rails_embeddings_only():
- """Test that reasoning traces can be enabled when embeddings_only is True, even with user messages."""
-
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: False
- rails:
- dialog:
- user_messages:
- embeddings_only: True
- """,
- colang_content="""
- define user express greeting
- "hello"
- "hi"
- """,
- )
-
-
-def test_reasoning_traces_with_bot_messages_embeddings_only():
- """Test that reasoning traces can be enabled when embeddings_only is True, even with bot messages."""
-
- _ = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_reasoning_traces: False
- rails:
- dialog:
- user_messages:
- embeddings_only: True
- """,
- colang_content="""
- define bot express greeting
- "Hello there!"
- """,
- )
diff --git a/tests/test_content_safety_actions.py b/tests/test_content_safety_actions.py
index 12ebf06b0..fd934d10e 100644
--- a/tests/test_content_safety_actions.py
+++ b/tests/test_content_safety_actions.py
@@ -88,9 +88,7 @@ async def test_content_safety_parsing(
expected_violations,
):
llms = fake_llm("irrelevant")
- mock_parsed = MagicMock()
- mock_parsed.text = parsed_text
- mock_task_manager.parse_task_output.return_value = mock_parsed
+ mock_task_manager.parse_task_output.return_value = parsed_text
result = await check_fn(
llms=llms,
diff --git a/tests/test_content_safety_integration.py b/tests/test_content_safety_integration.py
index 52702f19c..d0485e0f2 100644
--- a/tests/test_content_safety_integration.py
+++ b/tests/test_content_safety_integration.py
@@ -40,13 +40,11 @@ def _create_mock_setup(llm_responses, parsed_result):
llms = {"test_model": mock_llm}
mock_task_manager = MagicMock()
- mock_parsed_result = MagicMock()
- mock_parsed_result.text = parsed_result
mock_task_manager.render_task_prompt.return_value = "test prompt"
mock_task_manager.get_stop_tokens.return_value = []
mock_task_manager.get_max_tokens.return_value = 3
- mock_task_manager.parse_task_output.return_value = mock_parsed_result
+ mock_task_manager.parse_task_output.return_value = parsed_result
return llms, mock_task_manager
diff --git a/tests/test_filters.py b/tests/test_filters.py
index f97b288d2..85dd4309d 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -19,9 +19,6 @@
import pytest
from nemoguardrails.llm.filters import (
- ReasoningExtractionResult,
- extract_and_strip_trace,
- find_reasoning_tokens_position,
first_turns,
last_turns,
to_chat_messages,
@@ -95,230 +92,6 @@ def test_last_turns():
assert last_turns(colang_history, 2) == colang_history
-def _build_test_string(parts: List[Union[str, Tuple[str, int]]]) -> str:
- """Builds a test string from a list of parts.
-
- Each part can be a literal string or a (character, count) tuple.
- Example: [("a", 3), "[START]", ("b", 5)] -> "aaa[START]bbbbb"
- """
- result = []
- for part in parts:
- if isinstance(part, str):
- result.append(part)
- elif isinstance(part, tuple) and len(part) == 2:
- char, count = part
- result.append(char * count)
- else:
- raise TypeError(f"Invalid part type in _build_test_string: {part}")
- return "".join(result)
-
-
-@pytest.mark.parametrize(
- "response, start_token, end_token, expected",
- [
- (
- _build_test_string(
- [
- ("a", 5),
- "[START]",
- ("b", 10),
- "[END]",
- ("c", 5),
- ]
- ),
- "[START]",
- "[END]",
- (5, 22), # 5 a's + 7 START + 10 b = 22
- ),
- # multiple reasoning sections
- (
- _build_test_string(
- [
- ("a", 3),
- "[START]",
- ("b", 4),
- "[END]",
- ("c", 3),
- "[START]",
- ("d", 4),
- "[END]",
- ("e", 3),
- ]
- ),
- "[START]",
- "[END]",
- (
- 3,
- 33,
- ),
- ),
- (
- _build_test_string(
- [
- ("a", 2),
- "[START]",
- ("b", 2),
- "[START]",
- ("c", 2),
- "[END]",
- ("d", 2),
- "[END]",
- ("e", 2),
- ]
- ),
- "[START]",
- "[END]",
- (
- 2,
- 27,
- ),
- ),
- (
- _build_test_string([("a", 10)]),
- "[START]",
- "[END]",
- (0, -1), # no tokens found, start_index is 0
- ),
- (
- _build_test_string(
- [
- ("a", 5),
- "[START]",
- ("b", 5),
- ]
- ),
- "[START]",
- "[END]",
- (5, -1), # [START] at pos 5, no end token
- ),
- (
- _build_test_string(
- [
- ("a", 5),
- "[END]",
- ("b", 5),
- ]
- ),
- "[START]",
- "[END]",
- (0, 5), # no start token so 0, end at pos 5
- ),
- (
- "",
- "[START]",
- "[END]",
- (0, -1), # empty string, start_index is 0
- ),
- ],
-)
-def test_find_token_positions_for_removal(response, start_token, end_token, expected):
- """Test finding token positions for removal.
-
- Test cases use _build_test_string for clarity and mathematical obviousness.
- """
- assert find_reasoning_tokens_position(response, start_token, end_token) == expected
-
-
-@pytest.mark.parametrize(
- "response, start_token, end_token, expected_text, expected_trace",
- [
- (
- "This is an example [START]hidden reasoning[END] of a response.",
- "[START]",
- "[END]",
- "This is an example of a response.",
- "[START]hidden reasoning[END]",
- ),
- (
- "Before [START]first[END] middle [START]second[END] after.",
- "[START]",
- "[END]",
- "Before after.",
- "[START]first[END] middle [START]second[END]",
- ),
- (
- "Text [START] first [START] nested [END] second [END] more text.",
- "[START]",
- "[END]",
- "Text more text.",
- "[START] first [START] nested [END] second [END]",
- ),
- (
- "No tokens here",
- "[START]",
- "[END]",
- "No tokens here",
- None,
- ),
- (
- "Only [START] start token",
- "[START]",
- "[END]",
- "Only [START] start token",
- None,
- ),
- (
- "Only end token [END]",
- "[START]",
- "[END]",
- "",
- "Only end token [END]",
- ),
- (
- "",
- "[START]",
- "[END]",
- "",
- None,
- ),
- # End token before start token (tests the final return path)
- (
- "some [END] text [START]",
- "[START]",
- "[END]",
- "some [END] text [START]",
- None,
- ),
- # Original test cases adapted
- (
- "[END] Out of order [START] tokens [END] example.",
- "[START]",
- "[END]",
- "[END] Out of order example.",
- "[START] tokens [END]",
- ),
- (
- "[START] nested [START] tokens [END] out of [END] order.",
- "[START]",
- "[END]",
- " order.",
- "[START] nested [START] tokens [END] out of [END]",
- ),
- (
- "[END] [START] [START] example [END] text.",
- "[START]",
- "[END]",
- "[END] text.",
- "[START] [START] example [END]",
- ),
- (
- "example text.",
- "[START]",
- "[END]",
- "example text.",
- None,
- ),
- ],
-)
-def test_extract_and_strip_trace(
- response, start_token, end_token, expected_text, expected_trace
-):
- """Tests the extraction and stripping of reasoning traces."""
- result = extract_and_strip_trace(response, start_token, end_token)
- assert result.text == expected_text
- assert result.reasoning_trace == expected_trace
-
-
class TestToChatMessages:
def test_to_chat_messages_with_text_only(self):
"""Test to_chat_messages with text-only messages."""
diff --git a/tests/test_llm_task_manager.py b/tests/test_llm_task_manager.py
index 7897e55b6..e8e00550e 100644
--- a/tests/test_llm_task_manager.py
+++ b/tests/test_llm_task_manager.py
@@ -298,167 +298,6 @@ def test_stop_configuration_parameter():
assert stop_token in task_prompt.stop
-def test_preprocess_events_removes_reasoning_traces():
- """Test that reasoning traces are removed from bot messages in rendered prompts."""
- config = RailsConfig.from_content(
- yaml_content=textwrap.dedent(
- """
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- start_token: "<think>"
- end_token: "</think>"
- rails:
- output:
- apply_to_reasoning_traces: true
- prompts:
- - task: generate_user_intent
- content: |-
- {% if examples %}{{ examples }}{% endif %}
- {{ history | colang }}
- user "{{ user_input }}"
- user intent:
- """
- )
- )
-
- llm_task_manager = LLMTaskManager(config)
-
- events = [
- {"type": "UtteranceUserActionFinished", "final_transcript": "Hello"},
- {
- "type": "StartUtteranceBotAction",
- "script": "<think>Let me think how to respond some crazy COT</think>Hi there!",
- },
- {"type": "UtteranceUserActionFinished", "final_transcript": "How are you?"},
- ]
-
- rendered_prompt = llm_task_manager.render_task_prompt(
- task=Task.GENERATE_USER_INTENT,
- context={"user_input": "How are you?", "examples": ""},
- events=events,
- )
-
- assert isinstance(rendered_prompt, str)
-
- assert "<think>" not in rendered_prompt
- assert "</think>" not in rendered_prompt
- assert "Let me think how to respond..." not in rendered_prompt
-
- assert "Hi there!" in rendered_prompt
-
-
-def test_preprocess_events_preserves_original_events():
- """Test that _preprocess_events_for_prompt doesn't modify the original events."""
- config = RailsConfig.from_content(
- yaml_content=textwrap.dedent(
- """
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- start_token: "<think>"
- end_token: "</think>"
- rails:
- output:
- apply_to_reasoning_traces: true
- """
- )
- )
-
- llm_task_manager = LLMTaskManager(config)
-
- original_events = [
- {"type": "UtteranceUserActionFinished", "final_transcript": "Hello"},
- {
- "type": "StartUtteranceBotAction",
- "script": "<think>Let me think how to respond some crazy COT</think>Hi there!",
- },
- {"type": "UtteranceUserActionFinished", "final_transcript": "How are you?"},
- ]
-
- events_copy = copy.deepcopy(original_events)
-
- processed_events = llm_task_manager._preprocess_events_for_prompt(events_copy)
-
- assert events_copy == original_events
-
- assert "<think>" not in processed_events[1]["script"]
- assert "</think>" not in processed_events[1]["script"]
- assert processed_events[1]["script"] == "Hi there!"
-
-
-def test_reasoning_traces_not_included_in_prompt_history():
- """Test that reasoning traces don't get included in prompt history for subsequent LLM calls."""
- config = RailsConfig.from_content(
- yaml_content=textwrap.dedent(
- """
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- start_token: "<think>"
- end_token: "</think>"
- rails:
- output:
- apply_to_reasoning_traces: true
- prompts:
- - task: generate_user_intent
- content: |-
- {% if examples %}{{ examples }}{% endif %}
- Previous conversation:
- {{ history | colang }}
-
- Current user message:
- user "{{ user_input }}"
- user intent:
- """
- )
- )
-
- llm_task_manager = LLMTaskManager(config)
-
- events = [
- {"type": "UtteranceUserActionFinished", "final_transcript": "Hello"},
- {
- "type": "StartUtteranceBotAction",
- "script": "<think>I should greet the user back.</think>Hi there!",
- },
- {
- "type": "UtteranceUserActionFinished",
- "final_transcript": "What's the weather like?",
- },
- {
- "type": "StartUtteranceBotAction",
- "script": "<think>I should explain I don't have real-time weather data.</think>I don't have access to real-time weather information.",
- },
- {"type": "UtteranceUserActionFinished", "final_transcript": "Tell me about AI"},
- ]
-
- rendered_prompt = llm_task_manager.render_task_prompt(
- task=Task.GENERATE_USER_INTENT,
- context={"user_input": "Tell me about AI", "examples": ""},
- events=events,
- )
-
- assert isinstance(rendered_prompt, str)
-
- assert "I should greet the user back." not in rendered_prompt
- assert (
- "I should explain I don't have real-time weather data."
- not in rendered_prompt
- )
-
- assert (
- "Hi there!" in rendered_prompt
- or "I don't have access to real-time weather information." in rendered_prompt
- )
-
-
def test_get_task_model_with_empty_models():
"""Test that get_task_model returns None when models list is empty.
diff --git a/tests/test_llmrails_reasoning.py b/tests/test_llmrails_reasoning.py
deleted file mode 100644
index 27ab6d911..000000000
--- a/tests/test_llmrails_reasoning.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional
-
-import pytest
-
-from nemoguardrails import LLMRails, RailsConfig
-from tests.utils import FakeLLM
-
-
-@pytest.fixture
-def rails_config():
- return RailsConfig.parse_object(
- {
- "models": [
- {
- "type": "main",
- "engine": "fake",
- "model": "fake",
- "reasoning_config": {
- "start_token": "<think>",
- "end_token": "</think>",
- },
- }
- ],
- "user_messages": {
- "express greeting": ["Hello!"],
- "ask math question": ["What is 2 + 2?", "5 + 9"],
- },
- "flows": [
- {
- "elements": [
- {"user": "express greeting"},
- {"bot": "express greeting"},
- ]
- },
- {
- "elements": [
- {"user": "ask math question"},
- {"execute": "compute"},
- {"bot": "provide math response"},
- {"bot": "ask if user happy"},
- ]
- },
- ],
- "bot_messages": {
- "express greeting": ["Hello! How are you?"],
- "provide response": ["The answer is 234", "The answer is 1412"],
- },
- }
- )
-
-
-@pytest.mark.asyncio
-async def test_1(rails_config):
- llm = FakeLLM(
- responses=[
- "<think>some redundant CoT text 1</think>\n express greeting",
- "<think>some redundant CoT text 2</think>\n ask math question",
- '<think>some redundant CoT text 3</think>\n "The answer is 5"',
- '<think>some important COT text</think>\n "Are you happy with the result?"',
- ]
- )
-
- async def compute(what: Optional[str] = "2 + 3"):
- return eval(what)
-
- llm_rails = LLMRails(config=rails_config, llm=llm)
- llm_rails.runtime.register_action(compute)
-
- messages = [{"role": "user", "content": "Hello!"}]
- bot_message = await llm_rails.generate_async(messages=messages)
-
- assert bot_message == {"role": "assistant", "content": "Hello! How are you?"}
- messages.append(bot_message)
-
- messages.append({"role": "user", "content": "2 + 3"})
- bot_message = await llm_rails.generate_async(messages=messages)
- assert bot_message == {
- "role": "assistant",
- "content": "<think>some important COT text</think>The answer is 5\nAre you happy with the result?",
- }
diff --git a/tests/test_llmrails_reasoning_output_rails.py b/tests/test_llmrails_reasoning_output_rails.py
deleted file mode 100644
index e02b507b4..000000000
--- a/tests/test_llmrails_reasoning_output_rails.py
+++ /dev/null
@@ -1,312 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for LLM Rails reasoning output configuration and behavior.
-
-This module contains tests that verify the behavior of LLM Rails when handling
-reasoning traces in the output, including configuration options and guardrail
-behavior.
-"""
-
-from typing import Any, Dict, NamedTuple
-
-import pytest
-
-from nemoguardrails import RailsConfig
-from tests.utils import TestChat
-
-
-class ReasoningTraceTestCase(NamedTuple):
- """Test case for reasoning trace configuration.
-
- Attributes:
- description: description of the test case
- remove_reasoning_traces: Whether to remove reasoning traces in the model config
- apply_to_reasoning_traces: Whether to apply output rails to reasoning traces
- expected_think_tag: Whether the think tag should be present in the response
- expected_error_message: Whether the error message should be present in the response
- """
-
- description: str
- remove_reasoning_traces: bool
- apply_to_reasoning_traces: bool
- expected_think_tag: bool
- expected_error_message: bool
-
-
-async def check_sensitive_info(context: Dict[str, Any]) -> bool:
- """Check if the response contains sensitive information."""
- response = context.get("bot_message", "")
- prompt = context.get("user_message", "")
- input_text = response or prompt
- return "credit card" in input_text.lower() or any(
- c.isdigit() for c in input_text if c.isdigit() or c == "-"
- )
-
-
-async def check_think_tag_present(context: Dict[str, Any]) -> bool:
- """Check if the think tag is present in the bot's response."""
- response = context.get("bot_message", "")
- return "<think>" in response
-
-
-@pytest.fixture
-def base_config() -> RailsConfig:
- """Creates a base RailsConfig with common test configuration."""
- return RailsConfig.from_content(
- colang_content="""
- define flow check think tag
- $not_allowed = execute check_think_tag_present
- if $not_allowed
- bot informs tag not allowed
- stop
-
- define bot informs tag not allowed
- "think tag is not allowed it must be removed"
- """,
- yaml_content="""
- models:
- - type: main
- engine: fake
- model: fake
- colang_version: "1.0"
- rails:
- output:
- flows:
- - check think tag
- """,
- )
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "test_case",
- [
- ReasoningTraceTestCase(
- description="Remove reasoning traces and show error when guardrail is enabled",
- remove_reasoning_traces=True,
- apply_to_reasoning_traces=True,
- expected_think_tag=True,
- expected_error_message=True,
- ),
- ReasoningTraceTestCase(
- description="Preserve reasoning traces and hide error when guardrail is disabled",
- remove_reasoning_traces=True,
- apply_to_reasoning_traces=False,
- expected_think_tag=True,
- expected_error_message=False,
- ),
- ReasoningTraceTestCase(
- description="Preserve reasoning traces and show error when guardrail is enabled",
- remove_reasoning_traces=False,
- apply_to_reasoning_traces=True,
- expected_think_tag=True,
- expected_error_message=True,
- ),
- ReasoningTraceTestCase(
- description="Remove reasoning traces and show error when both flags are disabled",
- remove_reasoning_traces=False,
- apply_to_reasoning_traces=False,
- expected_think_tag=True,
- expected_error_message=True,
- ),
- ],
- ids=lambda tc: tc.description,
-)
-async def test_output_rails_reasoning_traces_configuration(
- base_config: RailsConfig,
- test_case: ReasoningTraceTestCase,
-) -> None:
- """Test output rails with different reasoning traces configurations.
-
- The test verifies the following behaviors based on configuration:
-
- 1. When remove_reasoning_traces=True:
- - The model is configured to remove reasoning traces
- - However, the actual removal depends on apply_to_reasoning_traces
-
- 2. When apply_to_reasoning_traces=True:
- - The output rail will check for and report think tags
- - Because we expect the think tag to be present as output rails explicitly requires it
-
- 3. When apply_to_reasoning_traces=False:
- - The output rails will check for think tags
- - No error message will be shown because it is not there to get blocked
-
- """
- base_config.models[
- 0
- ].reasoning_config.remove_reasoning_traces = test_case.remove_reasoning_traces
- base_config.rails.output.apply_to_reasoning_traces = (
- test_case.apply_to_reasoning_traces
- )
-
- chat = TestChat(
- base_config,
- llm_completions=[
- " I should think more Your kindness is appreciated"
- ],
- )
-
- chat.app.runtime.register_action(check_think_tag_present)
-
- messages = [{"role": "user", "content": "you are nice"}]
- response = await chat.app.generate_async(messages=messages)
-
- if test_case.expected_think_tag:
- assert (
- "" in response["content"]
- ), "Think tag should be present in response"
- else:
- assert (
- "" not in response["content"]
- ), "Think tag should not be present in response"
-
- if test_case.expected_error_message:
- assert (
- "think tag is not allowed" in response["content"]
- ), "Error message should be present"
- else:
- assert (
- "think tag is not allowed" not in response["content"]
- ), "Error message should not be present"
-
-
-@pytest.mark.asyncio
-async def test_output_rails_preserves_reasoning_traces() -> None:
- """Test that output rails preserve reasoning traces when configured to do so."""
- config = RailsConfig.from_content(
- colang_content="""
- define flow check sensitive info
- $not_allowed = execute check_sensitive_info
- if $not_allowed
- bot provide sanitized response
- stop
- define bot provide sanitized response
- "I cannot share sensitive information."
- """,
- yaml_content="""
- models:
- - type: main
- engine: fake
- model: fake
- reasoning_config:
- remove_reasoning_traces: True
- colang_version: "1.0"
- rails:
- output:
- flows:
- - check sensitive info
- apply_to_reasoning_traces: True
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- '<think> I should not share sensitive info </think>\n "Here is my credit card: 1234-5678-9012-3456"',
- ],
- )
-
- chat.app.runtime.register_action(check_sensitive_info)
-
- messages = [{"role": "user", "content": "What's your credit card number?"}]
- response = await chat.app.generate_async(messages=messages)
-
- assert "" in response["content"], "Reasoning traces should be preserved"
- assert (
- "I should not share sensitive info" in response["content"]
- ), "Reasoning content should be preserved"
- assert (
- "credit card" not in response["content"].lower()
- ), "Sensitive information should be removed"
-
-
-@pytest.mark.asyncio
-async def test_output_rails_without_reasoning_traces() -> None:
- """Test that output rails properly handle responses when reasoning traces are disabled."""
- config = RailsConfig.from_content(
- colang_content="""
- define flow check sensitive info
- $not_allowed = execute check_sensitive_info
- if $not_allowed
- bot provide sanitized response
- stop
- define flow check think tag
- $not_allowed = execute check_think_tag_present
- if $not_allowed
- bot says tag not allowed
- stop
-
- define bot says tag not allowed
- " tag is not allowed it must be removed"
-
- define bot provide sanitized response
- "I cannot share sensitive information."
- """,
- yaml_content="""
- models:
- - type: main
- engine: fake
- model: fake
- reasoning_config:
- remove_reasoning_traces: True
- colang_version: "1.0"
- rails:
- input:
- flows:
- - check sensitive info
- output:
- flows:
- - check sensitive info
- - check think tag
- apply_to_reasoning_traces: false
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- " I should think more Your credit card number is 1234-5678-9012-3456",
- ],
- )
-
- chat.app.runtime.register_action(check_sensitive_info)
- chat.app.runtime.register_action(check_think_tag_present)
-
- # case 1: Sensitive information is blocked by input rail
- messages = [{"role": "user", "content": "What's your credit card number?"}]
- response = await chat.app.generate_async(messages=messages)
-
- assert "" not in response["content"], "Think tag should not be present"
- assert (
- "I should not share sensitive info" not in response["content"]
- ), "Reasoning content should not be present"
- assert (
- response["content"] == "I cannot share sensitive information."
- ), "Should return sanitized response"
-
- # case 2: Think tag is preserved but content is sanitized
- messages = [{"role": "user", "content": "Tell me some numbers"}]
- response = await chat.app.generate_async(messages=messages)
-
- assert "" in response["content"], "Think tag should be present"
- assert (
- "I should not share sensitive info" not in response["content"]
- ), "Reasoning content should not be present"
- assert (
- response["content"]
- == " I should think more I cannot share sensitive information."
- ), "Should preserve think tag but sanitize content"
diff --git a/tests/test_reasoning_trace_context.py b/tests/test_reasoning_trace_context.py
deleted file mode 100644
index d1c0c6db3..000000000
--- a/tests/test_reasoning_trace_context.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import pytest
-
-from nemoguardrails import RailsConfig
-from nemoguardrails.actions.llm.utils import get_and_clear_reasoning_trace_contextvar
-from nemoguardrails.context import reasoning_trace_var
-from nemoguardrails.rails.llm.llmrails import GenerationOptions, GenerationResponse
-from tests.utils import TestChat
-
-
-def test_get_and_clear_reasoning_trace_contextvar():
- """Test that it correctly gets and clears the trace."""
- reasoning_trace_var.set(" oh COT again ")
-
- result = get_and_clear_reasoning_trace_contextvar()
-
- assert result == " oh COT again "
- assert reasoning_trace_var.get() is None
-
-
-def test_get_and_clear_reasoning_trace_contextvar_empty():
- """Test that it returns None when no trace exists."""
- reasoning_trace_var.set(None)
-
- result = get_and_clear_reasoning_trace_contextvar()
-
- assert result is None
-
-
-@pytest.mark.asyncio
-async def test_generate_async_trace_with_messages_and_options():
- """Test generate_async prepends reasoning trace when using generation options and messages."""
- config = RailsConfig.from_content(
- colang_content="""
- define user express greeting
- "hi"
- "hello"
-
- define bot express greeting
- "Hello! How can I assist you today?"
-
- define flow
- user express greeting
- bot express greeting
- """,
- yaml_content="""
- models: []
- rails:
- output:
- apply_to_reasoning_traces: true
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- "user express greeting",
- "bot express greeting",
- "Hello! How can I assist you today?",
- ],
- )
-
- reasoning_trace_var.set(" yet another COT ")
-
- options = GenerationOptions()
- result = await chat.app.generate_async(
- messages=[{"role": "user", "content": "hi"}], options=options
- )
-
- assert isinstance(result, GenerationResponse)
- assert isinstance(result.response, list)
- assert len(result.response) == 1
- assert (
- result.response[0]["content"]
- == " yet another COT Hello! How can I assist you today?"
- )
- assert reasoning_trace_var.get() is None
-
-
-@pytest.mark.asyncio
-async def test_generate_async_trace_with_prompt_and_options():
- """Test generate_async prepends reasoning trace using prompt and options"""
- config = RailsConfig.from_content(
- colang_content="""
- define user express greeting
- "hi"
- "hello"
-
- define bot express greeting
- "Hello! How can I assist you today?"
-
- define flow
- user express greeting
- bot express greeting
- """,
- yaml_content="""
- models: []
- rails:
- output:
- apply_to_reasoning_traces: true
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- "user express greeting",
- "bot express greeting",
- "Hello! How can I assist you today?",
- ],
- )
-
- reasoning_trace_var.set(" yet another COT ")
-
- options = GenerationOptions()
- result = await chat.app.generate_async(options=options, prompt="test prompt")
-
- assert isinstance(result, GenerationResponse)
- assert isinstance(result.response, str)
- assert (
- result.response
- == " yet another COT Hello! How can I assist you today?"
- )
- assert reasoning_trace_var.get() is None
-
-
-@pytest.mark.asyncio
-async def test_generate_async_trace_messages_only():
- """Test generate_async prepends reasoning trace when using only messages."""
- config = RailsConfig.from_content(
- colang_content="""
- define user express greeting
- "hi"
- "hello"
-
- define bot express greeting
- "Hello! How can I assist you today?"
-
- define flow
- user express greeting
- bot express greeting
- """,
- yaml_content="""
- models: []
- rails:
- output:
- apply_to_reasoning_traces: true
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- "user express greeting",
- "bot express greeting",
- "Hello! How can I assist you today?",
- ],
- )
-
- reasoning_trace_var.set(" yet another COT ")
-
- result = await chat.app.generate_async(messages=[{"role": "user", "content": "hi"}])
-
- assert isinstance(result, dict)
- assert result.get("role") == "assistant"
- assert (
- result.get("content")
- == " yet another COT Hello! How can I assist you today?"
- )
- assert reasoning_trace_var.get() is None
-
-
-@pytest.mark.asyncio
-async def test_generate_async_trace_with_prompt_only():
- """Test generate_async prepends reasoning trace when using prompt."""
- config = RailsConfig.from_content(
- colang_content="""
- define user express greeting
- "hi"
- "hello"
-
- define bot express greeting
- "Hello! How can I assist you today?"
-
- define flow
- user express greeting
- bot express greeting
- """,
- yaml_content="""
- models: []
- rails:
- output:
- apply_to_reasoning_traces: true
- """,
- )
-
- chat = TestChat(
- config,
- llm_completions=[
- "user express greeting",
- "bot express greeting",
- "Hello! How can I assist you today?",
- ],
- )
-
- reasoning_trace_var.set(" yet another COT ")
-
- result = await chat.app.generate_async(prompt="hi")
-
- assert (
- result == " yet another COT Hello! How can I assist you today?"
- )
- assert reasoning_trace_var.get() is None
diff --git a/tests/test_reasoning_traces.py b/tests/test_reasoning_traces.py
deleted file mode 100644
index c603a04b6..000000000
--- a/tests/test_reasoning_traces.py
+++ /dev/null
@@ -1,451 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from nemoguardrails.actions.llm.generation import (
- LLMGenerationActions,
- _get_apply_to_reasoning_traces,
- _process_parsed_output,
-)
-from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
-from nemoguardrails.context import (
- generation_options_var,
- llm_call_info_var,
- streaming_handler_var,
-)
-from nemoguardrails.llm.filters import extract_and_strip_trace
-from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
-from nemoguardrails.llm.types import Task
-from nemoguardrails.logging.explain import LLMCallInfo
-from nemoguardrails.rails.llm.config import Model, RailsConfig, ReasoningModelConfig
-
-
-def create_mock_config():
- config = MagicMock(spec=RailsConfig)
- config.rails = MagicMock()
- config.rails.output = MagicMock()
- config.rails.output.apply_to_reasoning_traces = False
- return config
-
-
-class TestReasoningTraces:
- """Test the reasoning traces functionality."""
-
- def test_remove_reasoning_traces_basic(self):
- """Test basic removal of reasoning traces."""
- input_text = "This is a \nSome reasoning here\nMore reasoning\n response."
- expected = "This is a response."
- result = extract_and_strip_trace(input_text, "<think>", "</think>")
- assert result.text == expected
-
- def test_remove_reasoning_traces_multiline(self):
- """Test removal of multiline reasoning traces."""
- input_text = """
- Here is my <think>
- I need to think about this...
- Step 1: Consider the problem
- Step 2: Analyze possibilities
- This makes sense
- </think> response after thinking.
- """
- expected = "\n Here is my response after thinking.\n "
- result = extract_and_strip_trace(input_text, "<think>", "</think>")
- assert result.text == expected
-
- def test_remove_reasoning_traces_multiple_sections(self):
- """Test removal of multiple reasoning trace sections."""
- input_text = "Start Reasoning 1 middle Reasoning 2 end."
- # Note: The current implementation removes all content between the first start and last end token
- # So the expected result is "Start end." not "Start middle end."
- expected = "Start end."
- result = extract_and_strip_trace(input_text, "<think>", "</think>")
- assert result.text == expected
-
- def test_remove_reasoning_traces_nested(self):
- """Test handling of nested reasoning trace markers (should be handled correctly)."""
- input_text = (
- "Begin Outer Inner Outer End."
- )
- expected = "Begin End."
- result = extract_and_strip_trace(input_text, "<think>", "</think>")
- assert result.text == expected
-
- def test_remove_reasoning_traces_unmatched(self):
- """Test handling of unmatched reasoning trace markers."""
- input_text = "Begin Unmatched end."
- result = extract_and_strip_trace(input_text, "<think>", "</think>")
- # We should keep the unmatched tag since it's not a complete section
- assert result.text == "Begin Unmatched end."
-
- @pytest.mark.asyncio
- async def test_task_manager_parse_task_output(self):
- """Test that the task manager correctly removes reasoning traces."""
- # mock config
- config = create_mock_config()
- # Create a ReasoningModelConfig
- reasoning_config = ReasoningModelConfig(
- remove_reasoning_traces=True,
- start_token="<think>",
- end_token="</think>",
- )
-
- # Create a Model with the reasoning_config
- model_config = Model(
- type="main",
- engine="test",
- model="test-model",
- reasoning_config=reasoning_config,
- )
-
- # mock the get_prompt and get_task_model functions
- with (
- patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
- patch(
- "nemoguardrails.llm.taskmanager.get_task_model"
- ) as mock_get_task_model,
- ):
- # Configure the mocks
- mock_get_prompt.return_value = MagicMock(output_parser=None)
- mock_get_task_model.return_value = model_config
-
- llm_task_manager = LLMTaskManager(config)
-
- # test parsing with reasoning traces
- input_text = (
- "This is a Some reasoning here final answer."
- )
- expected = "This is a final answer."
-
- result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result.text == expected
-
- @pytest.mark.asyncio
- async def test_parse_task_output_without_reasoning_config(self):
- """Test that parse_task_output works without a reasoning config."""
-
- config = create_mock_config()
-
- # a Model without reasoning_config
- model_config = Model(type="main", engine="test", model="test-model")
-
- # Mock the get_prompt and get_task_model functions
- with (
- patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
- patch(
- "nemoguardrails.llm.taskmanager.get_task_model"
- ) as mock_get_task_model,
- ):
- mock_get_prompt.return_value = MagicMock(output_parser=None)
- mock_get_task_model.return_value = model_config
-
- llm_task_manager = LLMTaskManager(config)
-
- # test parsing without a reasoning config
- input_text = (
- "This is a Some reasoning here final answer."
- )
- result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result.text == input_text
-
- @pytest.mark.asyncio
- async def test_parse_task_output_with_default_reasoning_traces(self):
- """Test that parse_task_output works with default reasoning traces."""
-
- config = create_mock_config()
-
- # Create a Model with default reasoning_config
- model_config = Model(
- type="main",
- engine="test",
- model="test-model",
- reasoning_config=ReasoningModelConfig(),
- )
-
- # Mock the get_prompt and get_task_model functions
- with (
- patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
- patch(
- "nemoguardrails.llm.taskmanager.get_task_model"
- ) as mock_get_task_model,
- ):
- mock_get_prompt.return_value = MagicMock(output_parser=None)
- mock_get_task_model.return_value = model_config
-
- llm_task_manager = LLMTaskManager(config)
-
- # test parsing with default reasoning traces
- input_text = "This is a Some reasoning here final answer."
- result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result.text == "This is a final answer."
-
- @pytest.mark.asyncio
- async def test_parse_task_output_with_output_parser(self):
- """Test that parse_task_output works with an output parser."""
-
- config = create_mock_config()
-
- # Create a Model with reasoning_config
- model_config = Model(
- type="main",
- engine="test",
- model="test-model",
- reasoning_config=ReasoningModelConfig(
- remove_reasoning_traces=True,
- start_token="<think>",
- end_token="</think>",
- ),
- )
-
- def mock_parser(text):
- return f"PARSED: {text}"
-
- # Mock the get_prompt and get_task_model functions
- with (
- patch("nemoguardrails.llm.taskmanager.get_prompt") as mock_get_prompt,
- patch(
- "nemoguardrails.llm.taskmanager.get_task_model"
- ) as mock_get_task_model,
- ):
- mock_get_prompt.return_value = MagicMock(output_parser="mock_parser")
- mock_get_task_model.return_value = model_config
-
- llm_task_manager = LLMTaskManager(config)
- llm_task_manager.output_parsers["mock_parser"] = mock_parser
-
- # test parsing with an output parser
- input_text = (
- "This is a Some reasoning here final answer."
- )
- result = llm_task_manager.parse_task_output(Task.GENERAL, input_text)
- assert result.text == "PARSED: This is a final answer."
-
- @pytest.mark.asyncio
- async def test_passthrough_llm_action_removes_reasoning(self):
- """Test that passthrough_llm_action correctly removes reasoning traces."""
- # mock the necessary components with proper nested structure
- config = MagicMock()
- # set required properties on the mock
- config.user_messages = {}
- config.bot_messages = {}
- config.config_path = "mock_path"
- config.flows = []
-
- # nested mock structure for rails
- rails_mock = MagicMock()
- dialog_mock = MagicMock()
- user_messages_mock = MagicMock()
- user_messages_mock.embeddings_only = False
- dialog_mock.user_messages = user_messages_mock
- rails_mock.dialog = dialog_mock
- config.rails = rails_mock
-
- llm = AsyncMock()
- llm_task_manager = MagicMock(spec=LLMTaskManager)
-
- # set up the mocked LLM to return text with reasoning traces
- llm.return_value = (
- "This is a Some reasoning here final answer."
- )
-
- # set up the mock llm_task_manager to properly process the output
- llm_task_manager.parse_task_output.return_value = "This is a final answer."
-
- # mock init method to avoid async initialization
- with patch.object(
- LLMGenerationActionsV2dotx, "init", AsyncMock(return_value=None)
- ):
- # create LLMGenerationActionsV2dotx with our mocks
- action_generator = LLMGenerationActionsV2dotx(
- config=config,
- llm=llm,
- llm_task_manager=llm_task_manager,
- get_embedding_search_provider_instance=MagicMock(),
- verbose=False,
- )
-
- # set context variables
- llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
- streaming_handler_var.set(None)
- generation_options_var.set(None)
-
- # mock the function directly to test the parse_task_output call
- action_generator.parse_task_output = llm_task_manager.parse_task_output
-
- # instead of calling passthrough_llm_action, let's directly test the functionality we want
- # by mocking what it does
- result = await llm(user_message="Test message")
- result = llm_task_manager.parse_task_output(Task.GENERAL, output=result)
-
- llm.assert_called_once()
-
- llm_task_manager.parse_task_output.assert_called_once_with(
- Task.GENERAL, output=llm.return_value
- )
-
- # verify the result has reasoning traces removed
- assert result == "This is a final answer."
-
- @pytest.mark.asyncio
- async def test_generate_bot_message_passthrough_removes_reasoning(self):
- """Test that generate_bot_message in passthrough mode correctly removes reasoning traces."""
- config = MagicMock()
- config.passthrough = True
-
- config.user_messages = {}
- config.bot_messages = {}
-
- rails_mock = MagicMock()
- output_mock = MagicMock()
- output_mock.flows = []
- rails_mock.output = output_mock
- config.rails = rails_mock
-
- llm = AsyncMock()
- llm_task_manager = MagicMock(spec=LLMTaskManager)
-
- # set up the mocked LLM to return text with reasoning traces
- llm.return_value = (
- "This is a Some reasoning here final answer."
- )
-
- llm_task_manager.parse_task_output.return_value = "This is a final answer."
-
- with patch.object(LLMGenerationActions, "init", AsyncMock(return_value=None)):
- # create LLMGenerationActions with our mocks
- action_generator = LLMGenerationActions(
- config=config,
- llm=llm,
- llm_task_manager=llm_task_manager,
- get_embedding_search_provider_instance=MagicMock(),
- verbose=False,
- )
-
- # create a mock bot intent event
- events = [
- {"type": "BotIntent", "intent": "respond"},
- {"type": "StartInternalSystemAction"},
- ]
-
- # set up context variables
- llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value))
- streaming_handler_var.set(None)
- generation_options_var.set(None)
-
- # mock the context vars
- context = {"user_message": "Hello"}
-
- # instead of calling generate_bot_message, let's directly test the functionality we want
- # by mocking what it does with passthrough mode
- result = await llm("Test message")
- result = llm_task_manager.parse_task_output(Task.GENERAL, output=result)
-
- # creating a simulated result object similar to what generate_bot_message would return
- class ActionResult:
- def __init__(self, events):
- self.events = events
-
- # creating a mock result with the parsed text
- mock_result = ActionResult(events=[{"type": "BotMessage", "text": result}])
-
- llm.assert_called_once()
-
- llm_task_manager.parse_task_output.assert_called_once_with(
- Task.GENERAL, output=llm.return_value
- )
-
- assert mock_result.events[0]["text"] == "This is a final answer."
-
-
-class TestProcessParsedOutput:
- """Test the _process_parsed_output function."""
-
- def test_process_parsed_output_with_reasoning_trace(self):
- """Test processing output with reasoning trace when guardrail is enabled."""
- result = ParsedTaskOutput(
- text="final answer",
- reasoning_trace="some reasoning",
- )
- output = _process_parsed_output(result, include_reasoning_trace=True)
- assert output == "some reasoningfinal answer"
-
- def test_process_parsed_output_with_reasoning_trace_disabled(self):
- """Test processing output with reasoning trace when guardrail is disabled."""
- result = ParsedTaskOutput(
- text="final answer",
- reasoning_trace="some reasoning",
- )
- output = _process_parsed_output(result, include_reasoning_trace=False)
- assert output == "final answer"
-
- def test_process_parsed_output_without_reasoning_trace(self):
- """Test processing output without reasoning trace."""
- result = ParsedTaskOutput(text="final answer", reasoning_trace=None)
- output = _process_parsed_output(result, include_reasoning_trace=True)
- assert output == "final answer"
-
-
-class TestGuardrailReasoningTraces:
- """Test the guardrail reasoning traces configuration."""
-
- def test_get_apply_to_reasoning_traces_enabled(self):
- """Test getting guardrail reasoning traces when enabled."""
- config = create_mock_config()
- config.rails.output.apply_to_reasoning_traces = True
- assert _get_apply_to_reasoning_traces(config) is True
-
- def test_get_apply_to_reasoning_traces_disabled(self):
- """Test getting guardrail reasoning traces when disabled."""
- config = create_mock_config()
- config.rails.output.apply_to_reasoning_traces = False
- assert _get_apply_to_reasoning_traces(config) is False
-
- def test_deprecated_remove_thinking_traces(self):
- """Test that using remove_thinking_traces issues a deprecation warning."""
-
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
-
- config = RailsConfig.from_content(
- yaml_content="""
- models:
- - type: main
- engine: openai
- model: gpt-3.5-turbo-instruct
- reasoning_config:
- remove_thinking_traces: False
- """
- )
-
- assert config.models[0].reasoning_config.remove_reasoning_traces is False
-
- assert config.models[0].reasoning_config.remove_thinking_traces is False
-
- found_expected_warning = False
- for warning in w:
- if (
- issubclass(warning.category, DeprecationWarning)
- and "remove_thinking_traces" in str(warning.message)
- and "remove_reasoning_traces" in str(warning.message)
- ):
- found_expected_warning = True
- break
-
- assert (
- found_expected_warning
- ), "Expected DeprecationWarning for remove_thinking_traces was not issued."