2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py

@@ -137,7 +137,7 @@ async def init(self):
             self._init_flows_index(),
         )

-    def _extract_user_message_example(self, flow: Flow):
+    def _extract_user_message_example(self, flow: Flow) -> None:
         """Heuristic to extract user message examples from a flow."""
         elements = [
             item
52 changes: 33 additions & 19 deletions nemoguardrails/integrations/langchain/runnable_rails.py

@@ -15,9 +15,10 @@

 from __future__ import annotations

-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union, cast

-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
 from langchain_core.runnables import Runnable
@@ -27,14 +28,14 @@

 from nemoguardrails import LLMRails, RailsConfig
 from nemoguardrails.integrations.langchain.utils import async_wrap
-from nemoguardrails.rails.llm.options import GenerationOptions
+from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse


 class RunnableRails(Runnable[Input, Output]):
     def __init__(
         self,
         config: RailsConfig,
-        llm: Optional[BaseLanguageModel] = None,
+        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
         tools: Optional[List[Tool]] = None,
         passthrough: bool = True,
         runnable: Optional[Runnable] = None,
@@ -67,12 +68,14 @@ def __init__(
         if self.passthrough_runnable:
             self._init_passthrough_fn()

-    def _init_passthrough_fn(self):
+    def _init_passthrough_fn(self) -> None:
         """Initialize the passthrough function for the LLM rails instance."""

         async def passthrough_fn(context: dict, events: List[dict]):
             # First, we fetch the input from the context
             _input = context.get("passthrough_input")
+            if self.passthrough_runnable is None:
+                raise ValueError("No passthrough runnable provided")
             async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
             _output = await async_wrapped_invoke(_input, self.config, **self.kwargs)
@@ -84,10 +87,11 @@ async def passthrough_fn(context: dict, events: List[dict]):

             return text, _output

-        self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
+        # Dynamically assign passthrough_fn to avoid type checker issues
+        setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn)

-    def __or__(self, other):
-        if isinstance(other, BaseLanguageModel):
+    def __or__(self, other) -> "RunnableRails[Input, Output]":  # type: ignore[override]
+        if isinstance(other, (BaseLLM, BaseChatModel)):
             self.llm = other
             self.rails.update_llm(other)
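The widened isinstance check means both completion-style models (BaseLLM) and chat models (BaseChatModel) can now be attached through the pipe operator. A hypothetical usage sketch (the model name and config path are illustrative, not from this PR):

```python
from langchain_openai import ChatOpenAI

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

config = RailsConfig.from_path("./config")  # assumed guardrails config directory
guardrails = RunnableRails(config)

# __or__ attaches the model to the rails rather than building a RunnableSequence.
chain = guardrails | ChatOpenAI(model="gpt-4o")
print(chain.invoke({"input": "Hello!"}))
```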

@@ -193,6 +197,9 @@ def invoke(
         res = self.rails.generate(
             messages=input_messages, options=GenerationOptions(output_vars=True)
         )
+        # When using output_vars=True, rails.generate returns a GenerationResponse
+        if not isinstance(res, GenerationResponse):
+            raise Exception(f"Expected GenerationResponse, got {type(res)}")
         context = res.output_data
         result = res.response
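The new guard encodes the contract that generate() called with output_vars=True returns a GenerationResponse carrying both the response messages and the requested context variables. A standalone sketch of that contract (config path and message content are illustrative):

```python
from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse

rails = LLMRails(RailsConfig.from_path("./config"))  # assumed config directory
res = rails.generate(
    messages=[{"role": "user", "content": "Hi there!"}],
    options=GenerationOptions(output_vars=True),
)
assert isinstance(res, GenerationResponse)
print(res.response)     # the generated bot message(s)
print(res.output_data)  # the requested context variables
```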

@@ -203,17 +210,16 @@

             result = result[0]

         if self.passthrough and self.passthrough_runnable:
-            passthrough_output = context.get("passthrough_output")
+            passthrough_output = context.get("passthrough_output") if context else None

             # If a rail was triggered (input or dialog), the passthrough_output
             # will not be set. In this case, we only set the output key to the
             # message that was received from the guardrail configuration.
             if passthrough_output is None:
-                passthrough_output = {
-                    self.passthrough_bot_output_key: result["content"]
-                }
+                content = result.get("content") if isinstance(result, dict) else result
+                passthrough_output = {self.passthrough_bot_output_key: content}

-            bot_message = context.get("bot_message")
+            bot_message = context.get("bot_message") if context else None

             # We make sure that, if the output rails altered the bot message, we
             # replace it in the passthrough_output
@@ -222,20 +228,28 @@
             elif isinstance(passthrough_output, dict):
                 passthrough_output[self.passthrough_bot_output_key] = bot_message

-            return passthrough_output
+            return cast(Output, passthrough_output)
         else:
             if isinstance(input, ChatPromptValue):
-                return AIMessage(content=result["content"])
+                content = result.get("content") if isinstance(result, dict) else result
+                # Ensure content is a string for AIMessage
+                content_str = str(content) if content is not None else ""
+                return cast(Output, AIMessage(content=content_str))
             elif isinstance(input, StringPromptValue):
                 if isinstance(result, dict):
-                    return result["content"]
+                    return cast(Output, result.get("content", ""))
                 else:
-                    return result
+                    return cast(Output, result)
             elif isinstance(input, dict):
                 user_input = input["input"]
                 if isinstance(user_input, str):
-                    return {"output": result["content"]}
+                    content = (
+                        result.get("content") if isinstance(result, dict) else result
+                    )
+                    return cast(Output, {"output": content})
                 elif isinstance(user_input, list):
-                    return {"output": result}
+                    return cast(Output, {"output": result})
                 else:
                     raise ValueError(f"Unexpected user_input type: {type(user_input)}")
             else:
                 raise ValueError(f"Unexpected input type: {type(input)}")