2 changes: 0 additions & 2 deletions examples/configs/nemotron/config.yml
@@ -2,5 +2,3 @@ models:
- type: main
engine: nim
model: nvidia/llama-3.1-nemotron-ultra-253b-v1
reasoning_config:
remove_reasoning_traces: False # Set True to remove traces from the internal tasks
49 changes: 1 addition & 48 deletions nemoguardrails/actions/llm/generation.py
@@ -57,7 +57,7 @@
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.prompts import get_prompt
from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
@@ -496,7 +496,6 @@ async def generate_user_intent(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_USER_INTENT, output=result
)
result = result.text

user_intent = get_first_nonempty_line(result)
if user_intent is None:
@@ -594,10 +593,6 @@ async def generate_user_intent(
Task.GENERAL, output=text
)

text = _process_parsed_output(
text, self._include_reasoning_traces()
)

else:
# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -639,8 +634,6 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)

text = _process_parsed_output(text, self._include_reasoning_traces())
text = text.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -750,7 +743,6 @@ async def generate_next_step(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_NEXT_STEPS, output=result
)
result = result.text

# If we don't have multi-step generation enabled, we only look at the first line.
if not self.config.enable_multi_step_generation:
@@ -1036,10 +1028,6 @@ async def generate_bot_message(
Task.GENERAL, output=result
)

result = _process_parsed_output(
result, self._include_reasoning_traces()
)

log.info(
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
time() - t0,
@@ -1111,10 +1099,6 @@ async def generate_bot_message(
Task.GENERATE_BOT_MESSAGE, output=result
)

result = _process_parsed_output(
result, self._include_reasoning_traces()
)

# TODO: catch openai.error.InvalidRequestError from exceeding max token length

result = get_multiline_response(result)
@@ -1212,7 +1196,6 @@ async def generate_value(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_VALUE, output=result
)
result = result.text

# We only use the first line for now
# TODO: support multi-line values?
@@ -1433,7 +1416,6 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
)
result = result.text

# TODO: Implement logic for generating more complex Colang next steps (multi-step),
# not just a single bot intent.
@@ -1516,7 +1498,6 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
result = _process_parsed_output(result, self._include_reasoning_traces())
text = result.strip()
if text.startswith('"'):
text = text[1:-1]
@@ -1529,10 +1510,6 @@ async def generate_intent_steps_message(
events=[new_event_dict("BotMessage", text=text)],
)

def _include_reasoning_traces(self) -> bool:
"""Get the configuration value for whether to include reasoning traces in output."""
return _get_apply_to_reasoning_traces(self.config)


def clean_utterance_content(utterance: str) -> str:
"""
# It should be translated to an actual \n character.
utterance = utterance.replace("\\n", "\n")
return utterance


def _record_reasoning_trace(trace: str) -> None:
"""Store the reasoning trace in context for later retrieval."""
reasoning_trace_var.set(trace)


def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
"""Combine trace and text if requested, otherwise just return text."""
return (trace + text) if (trace and include_reasoning) else text


def _process_parsed_output(
output: ParsedTaskOutput, include_reasoning_trace: bool
) -> str:
"""Record trace, then assemble the final LLM response."""
if reasoning_trace := output.reasoning_trace:
_record_reasoning_trace(reasoning_trace)
return _assemble_response(output.text, reasoning_trace, include_reasoning_trace)


def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
"""Get the configuration value for whether to include reasoning traces in output."""
return config.rails.output.apply_to_reasoning_traces
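Note: the repeated `result = result.text` deletions in this file (and in `v2_x/generation.py` below) all follow from one change: `parse_task_output` previously returned a `ParsedTaskOutput` wrapper that every call site had to unwrap, and it now returns the parsed string directly, with reasoning traces carried through response metadata instead. A minimal before/after sketch; the wrapper fields follow the `filters.py` diff below, everything else is illustrative:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParsedTaskOutput:  # the wrapper type removed by this PR
    text: str
    reasoning_trace: Optional[str] = None


# Before: every call site unwrapped the wrapper.
parsed = ParsedTaskOutput(text="Hello!", reasoning_trace="<think>...</think>")
result = parsed.text

# After: parse_task_output returns the plain string, so the
# `result = result.text` lines shown deleted above disappear.
result = "Hello!"
assert result == parsed.text
```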
15 changes: 14 additions & 1 deletion nemoguardrails/actions/llm/utils.py
@@ -179,7 +179,10 @@ def _store_tool_calls(response) -> None:


def _store_response_metadata(response) -> None:
"""Store response metadata excluding content for metadata preservation."""
"""Store response metadata excluding content for metadata preservation.

Also extracts reasoning content from additional_kwargs if available from LangChain.
"""
if hasattr(response, "model_fields"):
metadata = {}
for field_name in response.model_fields:
@@ -188,6 +191,16 @@ def _store_response_metadata(response) -> None:
): # Exclude content since it may be modified by rails
metadata[field_name] = getattr(response, field_name)
llm_response_metadata_var.set(metadata)

if hasattr(response, "additional_kwargs"):
additional_kwargs = response.additional_kwargs
if (
isinstance(additional_kwargs, dict)
and "reasoning_content" in additional_kwargs
):
reasoning_content = additional_kwargs["reasoning_content"]
Review comment (Collaborator):

Lines 201-203 aren't covered by tests in this PR. These are pretty important since they're where we pull out the additional_kwargs and put them in the context var (!) Are there tests in another one of the stacked PRs that cover these?

if reasoning_content:
reasoning_trace_var.set(reasoning_content)
else:
llm_response_metadata_var.set(None)

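The reviewer's point above could be covered with a focused unit test. A minimal sketch, assuming `reasoning_trace_var` lives in `nemoguardrails.context` and that the response is a LangChain `AIMessage` (both assumptions; adjust the imports to the actual module layout):

```python
from langchain_core.messages import AIMessage

from nemoguardrails.actions.llm.utils import _store_response_metadata
from nemoguardrails.context import reasoning_trace_var  # assumed location


def test_reasoning_content_is_stored_in_context_var():
    response = AIMessage(
        content="Hello!",
        additional_kwargs={"reasoning_content": "The user greeted me."},
    )
    _store_response_metadata(response)
    assert reasoning_trace_var.get() == "The user greeted me."
```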
14 changes: 0 additions & 14 deletions nemoguardrails/actions/v2_x/generation.py
@@ -318,8 +318,6 @@ async def generate_user_intent( # pyright: ignore (TODO - Signature completely
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
)

result = result.text

user_intent = get_first_nonempty_line(result)
# GPT-4o often adds 'user intent: ' in front
if user_intent and ":" in user_intent:
@@ -401,8 +399,6 @@ async def generate_user_intent_and_bot_action(
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
)

result = result.text

user_intent = get_first_nonempty_line(result)

if user_intent and ":" in user_intent:
@@ -578,8 +574,6 @@ async def generate_flow_from_instructions(
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
)

result = result.text

# TODO: why this is not part of a filter or output_parser?
#
lines = _remove_leading_empty_lines(result).split("\n")
@@ -660,8 +654,6 @@ async def generate_flow_from_name(
task=Task.GENERATE_FLOW_FROM_NAME, output=result
)

result = result.text

lines = _remove_leading_empty_lines(result).split("\n")

if lines[0].startswith("flow"):
@@ -736,8 +728,6 @@ async def generate_flow_continuation(
task=Task.GENERATE_FLOW_CONTINUATION, output=result
)

result = result.text

lines = _remove_leading_empty_lines(result).split("\n")

if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -869,8 +859,6 @@ async def generate_value( # pyright: ignore (TODO - different arguments to base
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
)

result = result.text

# We only use the first line for now
# TODO: support multi-line values?
value = result.strip().split("\n")[0]
@@ -994,8 +982,6 @@ async def generate_flow(
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
)

result = result.text

result = _remove_leading_empty_lines(result)
lines = result.split("\n")
if "codeblock" in lines[0]:
2 changes: 0 additions & 2 deletions nemoguardrails/library/content_safety/actions.py
@@ -83,7 +83,6 @@ async def content_safety_check_input(
)

result = llm_task_manager.parse_task_output(task, output=result)
result = result.text

is_safe, *violated_policies = result

@@ -165,7 +164,6 @@

result = llm_task_manager.parse_task_output(task, output=result)

result = result.text
is_safe, *violated_policies = result

return {"allowed": is_safe, "policy_violations": violated_policies}
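The unpacking above relies on the `is_content_safe` output parser yielding a sequence whose first element is the safety verdict. A hedged sketch of that shape (the example values are assumptions, not actual policy names):

```python
# Assumed parser output shape: (safety_flag, *violated_policies)
result = (False, "violence", "harassment")
is_safe, *violated_policies = result
print({"allowed": is_safe, "policy_violations": violated_policies})
# {'allowed': False, 'policy_violations': ['violence', 'harassment']}
```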
1 change: 0 additions & 1 deletion nemoguardrails/library/self_check/facts/actions.py
@@ -85,7 +85,6 @@ async def self_check_facts(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_not_safe = result[0]

result = float(not is_not_safe)
1 change: 0 additions & 1 deletion nemoguardrails/library/self_check/input_check/actions.py
@@ -87,7 +87,6 @@ async def self_check_input(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_safe = result[0]

if not is_safe:
1 change: 0 additions & 1 deletion nemoguardrails/library/self_check/output_check/actions.py
@@ -91,7 +91,6 @@ async def self_check_output(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_safe = result[0]

return is_safe
107 changes: 0 additions & 107 deletions nemoguardrails/llm/filters.py
@@ -24,16 +24,6 @@
)


@dataclass
class ReasoningExtractionResult:
"""
Holds cleaned response text and optional chain-of-thought reasoning trace extracted from LLM output.
"""

text: str
reasoning_trace: Optional[str] = None


def colang(events: List[dict]) -> str:
"""Filter that turns an array of events into a colang history."""
return get_colang_history(events)
@@ -448,100 +438,3 @@ def conversation_to_events(conversation: List) -> List[dict]:
)

return events


def _find_token_positions_for_removal(
response: str, start_token: Optional[str], end_token: Optional[str]
) -> Tuple[int, int]:
"""Helper function to find token positions specifically for text removal.

This is useful, for example, to remove reasoning traces from a reasoning LLM response.

This is optimized for the removal use case:
1. Uses find() for first start token
2. Uses rfind() for last end token
3. Sets start_index to 0 if start token is missing

Args:
response(str): The text to search in
start_token(str): The token marking the start of text to remove
end_token(str): The token marking the end of text to remove

Returns:
A tuple of (start_index, end_index) marking the span to remove;
both indices are -1 if start_token and end_token are not provided.
"""
if not start_token or not end_token:
return -1, -1

start_index = response.find(start_token)
# if the start index is missing, this is probably a continuation of a bot message
# started in the prompt.
if start_index == -1:
start_index = 0

end_index = response.rfind(end_token)

return start_index, end_index


def find_reasoning_tokens_position(
response: str, start_token: Optional[str], end_token: Optional[str]
) -> Tuple[int, int]:
"""Finds the positions of the first start token and the last end token.

This is intended to find the outermost boundaries of potential
reasoning sections, typically for removal.

Args:
response(str): The text to search in.
start_token(Optional[str]): The token marking the start of reasoning.
end_token(Optional[str]): The token marking the end of reasoning.

Returns:
A tuple (start_index, end_index).
- start_index: Position of the first `start_token`, or 0 if not found.
- end_index: Position of the last `end_token`, or -1 if not found.
"""

return _find_token_positions_for_removal(response, start_token, end_token)


def extract_and_strip_trace(
response: str, start_token: str, end_token: str
) -> ReasoningExtractionResult:
"""Extracts and removes reasoning traces from the given text.

This function identifies reasoning traces in the text that are marked
by specific start and end tokens. It extracts these traces, removes
them from the original text, and returns both the cleaned text and
the extracted reasoning trace.

Args:
response (str): The text to process.
start_token (str): The token marking the start of a reasoning trace.
end_token (str): The token marking the end of a reasoning trace.

Returns:
ReasoningExtractionResult: An object containing the cleaned text
without reasoning traces and the extracted reasoning trace, if any.
"""

start_index, end_index = find_reasoning_tokens_position(
response, start_token, end_token
)
# handles invalid/empty tokens returned as (-1, -1)
if start_index == -1 and end_index == -1:
return ReasoningExtractionResult(text=response, reasoning_trace=None)
# end token is missing
if end_index == -1:
return ReasoningExtractionResult(text=response, reasoning_trace=None)
# extract if tokens are present and start < end
if start_index < end_index:
reasoning_trace = response[start_index : end_index + len(end_token)]
cleaned_text = response[:start_index] + response[end_index + len(end_token) :]
return ReasoningExtractionResult(
text=cleaned_text, reasoning_trace=reasoning_trace
)

return ReasoningExtractionResult(text=response, reasoning_trace=None)
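For context, the extraction behavior deleted here is small enough to sketch standalone. This mirrors `extract_and_strip_trace` above (first start token, last end token, trace returned with its markers), using DeepSeek-style `<think>` markers as an assumed example of reasoning tokens:

```python
from typing import Optional, Tuple


def strip_reasoning(
    response: str, start_token: str, end_token: str
) -> Tuple[str, Optional[str]]:
    """Mirror of the deleted helpers: outermost start/end token boundaries."""
    start = response.find(start_token)
    if start == -1:
        start = 0  # trace likely began in the prompt
    end = response.rfind(end_token)
    if end == -1 or start >= end:
        return response, None  # no complete trace to remove
    end += len(end_token)
    return response[:start] + response[end:], response[start:end]


text, trace = strip_reasoning(
    "<think>User says hi.</think>Hello!", "<think>", "</think>"
)
print(text)   # Hello!
print(trace)  # <think>User says hi.</think>
```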