From 71d00f083fb59bda34c82b82eea85602c1710265 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 2 Sep 2025 11:17:40 -0500 Subject: [PATCH 1/2] Dummy commit to set up the chore/type-clean-guardrails PR and branch --- nemoguardrails/actions/llm/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..cd11e70a7 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -137,7 +137,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item From aa69ee9e3cb91b74452a2010607d7d1dab803894 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 9 Sep 2025 17:49:57 -0500 Subject: [PATCH 2/2] Cleaned integrations directory --- .../integrations/langchain/runnable_rails.py | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index 1eb282848..07bba63dd 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -15,9 +15,10 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, List, Optional, Union, cast -from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.llms import BaseLLM from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.prompt_values import ChatPromptValue, StringPromptValue from langchain_core.runnables import Runnable @@ -27,14 +28,14 @@ from nemoguardrails import LLMRails, RailsConfig from nemoguardrails.integrations.langchain.utils import async_wrap -from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse class RunnableRails(Runnable[Input, Output]): def __init__( self, config: RailsConfig, - llm: Optional[BaseLanguageModel] = None, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, tools: Optional[List[Tool]] = None, passthrough: bool = True, runnable: Optional[Runnable] = None, @@ -67,12 +68,14 @@ def __init__( if self.passthrough_runnable: self._init_passthrough_fn() - def _init_passthrough_fn(self): + def _init_passthrough_fn(self) -> None: """Initialize the passthrough function for the LLM rails instance.""" async def passthrough_fn(context: dict, events: List[dict]): # First, we fetch the input from the context _input = context.get("passthrough_input") + if self.passthrough_runnable is None: + raise ValueError("No passthrough runnable provided") async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke) _output = await async_wrapped_invoke(_input, self.config, **self.kwargs) @@ -84,10 +87,11 @@ async def passthrough_fn(context: dict, events: List[dict]): return text, _output - self.rails.llm_generation_actions.passthrough_fn = passthrough_fn + # Dynamically assign passthrough_fn to avoid type checker issues + setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn) - def __or__(self, other): - if isinstance(other, BaseLanguageModel): + def __or__(self, other) -> "RunnableRails[Input, Output]": # type: ignore[override] + if isinstance(other, (BaseLLM, BaseChatModel)): self.llm = other self.rails.update_llm(other) @@ -193,6 +197,9 @@ def invoke( res = self.rails.generate( messages=input_messages, options=GenerationOptions(output_vars=True) ) + # When using output_vars=True, rails.generate returns a GenerationResponse + if not isinstance(res, GenerationResponse): + raise Exception(f"Expected GenerationResponse, got {type(res)}") context = res.output_data result = res.response @@ -203,17 +210,16 @@ def invoke( result = result[0] if self.passthrough and self.passthrough_runnable: - passthrough_output = context.get("passthrough_output") + passthrough_output = context.get("passthrough_output") if context else None # If a rail was triggered (input or dialog), the passthrough_output # will not be set. In this case, we only set the output key to the # message that was received from the guardrail configuration. if passthrough_output is None: - passthrough_output = { - self.passthrough_bot_output_key: result["content"] - } + content = result.get("content") if isinstance(result, dict) else result + passthrough_output = {self.passthrough_bot_output_key: content} - bot_message = context.get("bot_message") + bot_message = context.get("bot_message") if context else None # We make sure that, if the output rails altered the bot message, we # replace it in the passthrough_output @@ -222,20 +228,28 @@ def invoke( elif isinstance(passthrough_output, dict): passthrough_output[self.passthrough_bot_output_key] = bot_message - return passthrough_output + return cast(Output, passthrough_output) else: if isinstance(input, ChatPromptValue): - return AIMessage(content=result["content"]) + content = result.get("content") if isinstance(result, dict) else result + # Ensure content is a string for AIMessage + content_str = str(content) if content is not None else "" + return cast(Output, AIMessage(content=content_str)) elif isinstance(input, StringPromptValue): if isinstance(result, dict): - return result["content"] + return cast(Output, result.get("content", "")) else: - return result + return cast(Output, result) elif isinstance(input, dict): user_input = input["input"] if isinstance(user_input, str): - return {"output": result["content"]} + content = ( + result.get("content") if isinstance(result, dict) else result + ) + return cast(Output, {"output": content}) elif isinstance(user_input, list): - return {"output": result} + return cast(Output, {"output": result}) + else: + raise ValueError(f"Unexpected user_input type: {type(user_input)}") else: raise ValueError(f"Unexpected input type: {type(input)}")