diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index 9342dd628..b302eea2a 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -19,8 +19,9 @@ import inspect import logging import os +from importlib.machinery import ModuleSpec from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast from langchain.chains.base import Chain from langchain_core.runnables import Runnable @@ -51,7 +52,7 @@ def __init__( """ log.info("Initializing action dispatcher") - self._registered_actions = {} + self._registered_actions: Dict[str, Union[Type, Callable[..., Any]]] = {} if load_all_actions: # TODO: check for better way to find actions dir path or use constants.py @@ -78,9 +79,12 @@ def __init__( # Last, but not least, if there was a config path, we try to load actions # from there as well. if config_path: - config_path = config_path.split(",") - for path in config_path: - self.load_actions_from_path(Path(path.strip())) + split_config_path: List[str] = config_path.split(",") + + # Don't load actions if we have an empty list + if split_config_path: + for path in split_config_path: + self.load_actions_from_path(Path(path.strip())) # If there are any imported paths, we load the actions from there as well. if import_paths: @@ -120,26 +124,28 @@ def load_actions_from_path(self, path: Path): ) def register_action( - self, action: callable, name: Optional[str] = None, override: bool = True + self, action: Callable, name: Optional[str] = None, override: bool = True ): """Registers an action with the given name. Args: - action (callable): The action function. + action (Callable): The action function. name (Optional[str]): The name of the action. Defaults to None. override (bool): If an action already exists, whether it should be overridden or not. """ if name is None: action_meta = getattr(action, "action_meta", None) - name = action_meta["name"] if action_meta else action.__name__ + action_name = action_meta["name"] if action_meta else action.__name__ + else: + action_name = name # If we're not allowed to override, we stop. - if name in self._registered_actions and not override: + if action_name in self._registered_actions and not override: return - self._registered_actions[name] = action + self._registered_actions[action_name] = action - def register_actions(self, actions_obj: any, override: bool = True): + def register_actions(self, actions_obj: Any, override: bool = True): """Registers all the actions from the given object. Args: @@ -167,7 +173,7 @@ def has_registered(self, name: str) -> bool: name = self._normalize_action_name(name) return name in self.registered_actions - def get_action(self, name: str) -> callable: + def get_action(self, name: str) -> Optional[Callable]: """Get the registered action by name. Args: @@ -181,7 +187,7 @@ def get_action(self, name: str) -> callable: async def execute_action( self, action_name: str, params: Dict[str, Any] - ) -> Tuple[Union[str, Dict[str, Any]], str]: + ) -> Tuple[Union[Optional[str], Dict[str, Any]], str]: """Execute a registered action. 
Args: @@ -195,16 +201,21 @@ async def execute_action( action_name = self._normalize_action_name(action_name) if action_name in self._registered_actions: - log.info(f"Executing registered action: {action_name}") - fn = self._registered_actions.get(action_name, None) + log.info("Executing registered action: %s", action_name) + maybe_fn: Optional[Callable] = self._registered_actions.get( + action_name, None + ) + if not maybe_fn: + raise Exception(f"Action '{action_name}' is not registered.") + fn = cast(Callable, maybe_fn) # Actions that are registered as classes are initialized lazy, when # they are first used. if inspect.isclass(fn): fn = fn() self._registered_actions[action_name] = fn - if fn is not None: + if fn: try: # We support both functions and classes as actions if inspect.isfunction(fn) or inspect.ismethod(fn): @@ -245,7 +256,17 @@ async def execute_action( result = await runnable.ainvoke(input=params) else: # TODO: there should be a common base class here - result = fn.run(**params) + fn_run_func = getattr(fn, "run", None) + if not callable(fn_run_func): + raise Exception( + f"No 'run' method defined for action '{action_name}'." + ) + + fn_run_func_with_signature = cast( + Callable[[], Union[Optional[str], Dict[str, Any]]], + fn_run_func, + ) + result = fn_run_func_with_signature(**params) return result, "success" # We forward LLM Call exceptions @@ -288,6 +309,7 @@ def _load_actions_from_module(filepath: str): """ action_objects = {} filename = os.path.basename(filepath) + module = None if not os.path.isfile(filepath): log.error(f"{filepath} does not exist or is not a file.") @@ -298,13 +320,16 @@ def _load_actions_from_module(filepath: str): log.debug(f"Analyzing file {filename}") # Import the module from the file - spec = importlib.util.spec_from_file_location(filename, filepath) - if spec is None: + spec: Optional[ModuleSpec] = importlib.util.spec_from_file_location( + filename, filepath + ) + if not spec: log.error(f"Failed to create a module spec from {filepath}.") return action_objects module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + if spec.loader: + spec.loader.exec_module(module) # Loop through all members in the module and check for the `@action` decorator # If class has action decorator is_action class member is true @@ -313,19 +338,25 @@ def _load_actions_from_module(filepath: str): obj, "action_meta" ): try: - action_objects[obj.action_meta["name"]] = obj - log.info(f"Added {obj.action_meta['name']} to actions") + actionable_name: str = getattr(obj, "action_meta").get("name") + action_objects[actionable_name] = obj + log.info(f"Added {actionable_name} to actions") except Exception as e: log.error( - f"Failed to register {obj.action_meta['name']} in action dispatcher due to exception {e}" + f"Failed to register {name} in action dispatcher due to exception {e}" ) except Exception as e: + if module is None: + raise RuntimeError(f"Failed to load actions from module at {filepath}.") + if not module.__file__: + raise RuntimeError(f"No file found for module {module} at {filepath}.") + try: relative_filepath = Path(module.__file__).relative_to(Path.cwd()) except ValueError: relative_filepath = Path(module.__file__).resolve() log.error( - f"Failed to register {filename} from {relative_filepath} in action dispatcher due to exception: {e}" + f"Failed to register {filename} in action dispatcher due to exception: {e}" ) return action_objects diff --git a/nemoguardrails/actions/actions.py b/nemoguardrails/actions/actions.py index 
8149b0974..0cf595145 100644 --- a/nemoguardrails/actions/actions.py +++ b/nemoguardrails/actions/actions.py @@ -14,27 +14,42 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, TypedDict, Union - - -class ActionMeta(TypedDict, total=False): +from typing import ( + Any, + Callable, + List, + Optional, + Protocol, + Type, + TypedDict, + TypeVar, + Union, + cast, +) + + +class ActionMeta(TypedDict): name: str is_system_action: bool execute_async: bool output_mapping: Optional[Callable[[Any], bool]] +# Create a TypeVar to represent the decorated function or class +T = TypeVar("T", bound=Union[Callable[..., Any], Type[Any]]) + + def action( is_system_action: bool = False, name: Optional[str] = None, execute_async: bool = False, output_mapping: Optional[Callable[[Any], bool]] = None, -) -> Callable[[Union[Callable, type]], Union[Callable, type]]: +) -> Callable[[T], T]: """Decorator to mark a function or class as an action. Args: is_system_action (bool): Flag indicating if the action is a system action. - name (Optional[str]): The name to associate with the action. + name (str): The name to associate with the action. execute_async: Whether the function should be executed in async mode. output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result. It accepts the return value (e.g. the first element of a tuple) and return True if the output @@ -44,7 +59,7 @@ def action( callable: The decorated function or class. """ - def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]: + def decorator(fn_or_cls: Union[Callable, Type]) -> Union[Callable, Type]: """Inner decorator function to add metadata to the action. Args: @@ -52,8 +67,11 @@ def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]: """ fn_or_cls_target = getattr(fn_or_cls, "__func__", fn_or_cls) + # Action name is optional for the decorator, but mandatory for ActionMeta TypedDict + action_name: str = cast(str, name or fn_or_cls.__name__) + action_meta: ActionMeta = { - "name": name or fn_or_cls.__name__, + "name": action_name, "is_system_action": is_system_action, "execute_async": execute_async, "output_mapping": output_mapping, @@ -62,7 +80,7 @@ def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]: setattr(fn_or_cls_target, "action_meta", action_meta) return fn_or_cls - return decorator + return decorator # pyright: ignore (TODO - resolve how the Actionable Protocol doesn't resolve the issue) @dataclass diff --git a/nemoguardrails/actions/core.py b/nemoguardrails/actions/core.py index 368657d30..fd70f9363 100644 --- a/nemoguardrails/actions/core.py +++ b/nemoguardrails/actions/core.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import Any, Dict, Optional from nemoguardrails.actions.actions import ActionResult, action from nemoguardrails.utils import new_event_dict @@ -37,13 +37,13 @@ async def create_event( ActionResult: An action result containing the created event. 
""" - event_dict = new_event_dict( + event_dict: Dict[str, Any] = new_event_dict( event["_type"], **{k: v for k, v in event.items() if k != "_type"} ) # We add basic support for referring variables as values for k, v in event_dict.items(): if isinstance(v, str) and v[0] == "$": - event_dict[k] = context.get(v[1:]) + event_dict[k] = context.get(v[1:], None) if context else None return ActionResult(events=[event_dict]) diff --git a/nemoguardrails/actions/langchain/safetools.py b/nemoguardrails/actions/langchain/safetools.py index bbcb05698..e1c553bbb 100644 --- a/nemoguardrails/actions/langchain/safetools.py +++ b/nemoguardrails/actions/langchain/safetools.py @@ -19,11 +19,27 @@ """ import logging +from typing import TYPE_CHECKING from nemoguardrails.actions.validation import validate_input, validate_response log = logging.getLogger(__name__) +# Include these outside the try .. except so the Type-checker knows they're always imported +if TYPE_CHECKING: + from langchain_community.utilities import ( + ApifyWrapper, + BingSearchAPIWrapper, + GoogleSearchAPIWrapper, + GoogleSerperAPIWrapper, + OpenWeatherMapAPIWrapper, + SearxSearchWrapper, + SerpAPIWrapper, + WikipediaAPIWrapper, + WolframAlphaAPIWrapper, + ZapierNLAWrapper, + ) + try: from langchain_community.utilities import ( ApifyWrapper, diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 377b0bc5e..a230e5ce3 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -21,9 +21,10 @@ import re import sys import threading +from dataclasses import asdict from functools import lru_cache from time import time -from typing import Callable, List, Optional, Union, cast +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment @@ -113,8 +114,6 @@ def __init__( t = threading.Thread(target=asyncio.run, args=(self.init(),)) t.start() t.join() - else: - loop.run_until_complete(self.init()) self.llm_task_manager = llm_task_manager @@ -123,7 +122,7 @@ def __init__( # If set, in passthrough mode, this function will be used instead of # calling the LLM with the user input. - self.passthrough_fn = None + self.passthrough_fn: Optional[Callable[..., Awaitable[str]]] = None async def init(self): # For Colang 2.x we need to do some initial processing @@ -136,7 +135,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item @@ -148,43 +147,61 @@ def _extract_user_message_example(self, flow: Flow): el = elements[1] if isinstance(el, SpecOp): - if el.op == "match": - spec = cast(SpecOp, el).spec + spec_op: SpecOp = cast(SpecOp, el) + + if spec_op.op == "match": + # The SpecOp.spec type is Union[Spec, dict]. 
Convert Dict to Spec if it's provided + match_spec: Spec = ( + spec_op.spec + if isinstance(spec_op.spec, Spec) + else Spec(**cast(Dict, spec_op.spec)) + ) + if ( - not hasattr(spec, "name") - or spec.name != "UtteranceUserActionFinished" + not match_spec.name + or match_spec.name != "UtteranceUserActionFinished" ): return - if "final_transcript" not in spec.arguments: + if "final_transcript" not in match_spec.arguments: return # Extract the message and remove the double quotes - message = eval_expression(spec.arguments["final_transcript"], {}) + message = eval_expression(match_spec.arguments["final_transcript"], {}) if isinstance(message, str): self.user_messages[flow.name] = [message] - elif el.op == "await": - spec = cast(SpecOp, el).spec - if isinstance(spec, dict) and spec.get("_type") == "spec_or": - specs = spec.get("elements") - else: - assert isinstance(spec, Spec) - specs = [spec] - - for spec in specs: - if ( - not spec.name.startswith("user ") - or not spec.arguments - or not spec.arguments["$0"] - ): - continue + elif spec_op.op == "await": + # The SpecOp.spec type is Union[Spec, dict]. Need to convert to Dict to have `elements` field + # which isn't in the Spec definition + await_spec_dict: Dict[str, Any] = ( + asdict(spec_op.spec) + if isinstance(spec_op.spec, Spec) + else cast(Dict, spec_op.spec) + ) - message = eval_expression(spec.arguments["$0"], {}) - if isinstance(message, str): - if flow.name not in self.user_messages: - self.user_messages[flow.name] = [] - self.user_messages[flow.name].append(message) + if ( + isinstance(await_spec_dict, dict) + and await_spec_dict.get("_type") == "spec_or" + ): + specs = await_spec_dict.get("elements", None) + else: + specs = [await_spec_dict] + + if specs: + for spec in specs: + if ( + not spec["name"].startswith("user ") + or not spec["arguments"] + or not spec["arguments"]["$0"] + ): + continue + + message = eval_expression(spec["arguments"]["$0"], {}) + if isinstance(message, str): + if flow.name not in self.user_messages: + self.user_messages[flow.name] = [] + self.user_messages[flow.name].append(message) def _extract_bot_message_example(self, flow: Flow): # Quick heuristic to identify the user utterance examples @@ -192,23 +209,40 @@ def _extract_bot_message_example(self, flow: Flow): return el = flow.elements[1] + + if not isinstance(el, SpecOp): + return + + spec_op: SpecOp = cast(SpecOp, el) + spec: Dict[str, Any] = ( + asdict( + spec_op.spec + ) # TODO! Refactor this function as it's duplicated in many places + if isinstance(spec_op.spec, Spec) + else cast(Dict, spec_op.spec) + ) + if ( - not isinstance(el, SpecOp) - or not hasattr(el.spec, "name") - or el.spec.name != "UtteranceBotAction" - or "script" not in el.spec.arguments + not spec["name"] + or spec["name"] != "UtteranceUserActionFinished" + or "script" not in spec["arguments"] ): return # Extract the message and remove the double quotes - message = el.spec.arguments["script"][1:-1] + message = spec["arguments"]["script"][1:-1] self.bot_messages[flow.name] = [message] def _process_flows(self): """Process the provided flows to extract the user utterance examples.""" - flow: Flow - for flow in self.config.flows: + # Flows can be either Flow or Dict. 
Convert them all to Flow for following code + flows: List[Flow] = [ + cast(Flow, flow) if isinstance(flow, Flow) else Flow(**cast(Dict, flow)) + for flow in self.config.flows + ] + + for flow in flows: if flow.name.startswith("user "): self._extract_user_message_example(flow) @@ -302,6 +336,9 @@ async def _init_flows_index(self): def _get_general_instructions(self): """Helper to extract the general instruction.""" text = "" + if self.config.instructions is None: + return None + for instruction in self.config.instructions: if instruction.type == "general": text = instruction.content @@ -318,6 +355,9 @@ def _get_sample_conversation_two_turns(self): This is needed to be included to "seed" the conversation so that the model can follow the format more easily. """ + if self.config.sample_conversation is None: + return None + lines = self.config.sample_conversation.split("\n") i = 0 user_count = 0 @@ -343,7 +383,7 @@ async def generate_user_intent( events: List[dict], context: dict, config: RailsConfig, - llm: Optional[BaseLLM] = None, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, kb: Optional[KnowledgeBase] = None, ): """Generate the canonical form for what the user said i.e. user intent.""" @@ -354,10 +394,21 @@ async def generate_user_intent( ) # The last event should be the "StartInternalSystemAction" and the one before it the "UtteranceUserActionFinished". event = get_last_user_utterance_event(events) - assert event["type"] == "UserMessage" + if not event: + raise ValueError( + "No user message found in event stream. Unable to generate user intent." + ) + if event["type"] != "UserMessage": + raise ValueError( + f"Expected UserMessage event, but found {event['type']}. " + "Cannot generate user intent from this event type." + ) # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + # This can be None as some code-paths use embedding lookups rather than LLM generation + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) streaming_handler = streaming_handler_var.get() @@ -414,7 +465,7 @@ async def generate_user_intent( ) else: results = await self.user_message_index.search( - text=text, max_results=5 + text=text, max_results=5, threshold=None ) # We add these in reverse order so the most relevant is towards the end. for result in reversed(results): @@ -436,7 +487,9 @@ async def generate_user_intent( # We make this call with temperature 0 to have it as deterministic as possible. 
result = await llm_call( - llm, prompt, llm_params={"temperature": self.config.lowest_temperature} + generation_llm, + prompt, + llm_params={"temperature": self.config.lowest_temperature}, ) # Parse the output using the associated parser @@ -517,14 +570,24 @@ async def generate_user_intent( # Initialize the LLMCallInfo object llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value)) - generation_options: GenerationOptions = generation_options_var.get() - llm_params = ( - generation_options and generation_options.llm_params - ) or {} + gen_options: Optional[ + GenerationOptions + ] = generation_options_var.get() + + llm_params = (gen_options and gen_options.llm_params) or {} + + streaming_handler: Optional[ + StreamingHandler + ] = streaming_handler_var.get() + + custom_callback_handlers = ( + [streaming_handler] if streaming_handler else None + ) + text = await llm_call( - llm, + generation_llm, prompt, - custom_callback_handlers=[streaming_handler_var.get()], + custom_callback_handlers=custom_callback_handlers, llm_params=llm_params, ) text = self.llm_task_manager.parse_task_output( @@ -555,14 +618,20 @@ async def generate_user_intent( context={"relevant_chunks": relevant_chunks}, ) - generation_options: GenerationOptions = generation_options_var.get() + generation_options: Optional[ + GenerationOptions + ] = generation_options_var.get() llm_params = ( generation_options and generation_options.llm_params ) or {} + custom_callback_handlers = ( + [streaming_handler] if streaming_handler else None + ) + result = await llm_call( - llm, + generation_llm, prompt, - custom_callback_handlers=[streaming_handler_var.get()], + custom_callback_handlers=custom_callback_handlers, stop=["User:"], llm_params=llm_params, ) @@ -600,7 +669,12 @@ async def generate_user_intent( async def _search_flows_index(self, text, max_results): """Search the index of flows.""" - results = await self.flows_index.search(text=text, max_results=10) + if self.flows_index is None: + raise RuntimeError("No flows index found to search") + + results = await self.flows_index.search( + text=text, max_results=10, threshold=None + ) # we filter the results to keep only unique flows flows = set() @@ -625,10 +699,16 @@ async def generate_next_step( log.info("Phase 2 :: Generating next step ...") # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) # The last event should be the "StartInternalSystemAction" and the one before it the "UserIntent". 
event = get_last_user_intent_event(events) + if event is None: + raise RuntimeError( + "No last user intent found from which to generate next step" + ) # Currently, we only predict next step after a user intent using LLM if event["type"] == "UserIntent": @@ -661,7 +741,9 @@ async def generate_next_step( # We use temperature 0 for next step prediction as well result = await llm_call( - llm, prompt, llm_params={"temperature": self.config.lowest_temperature} + generation_llm, + prompt, + llm_params={"temperature": self.config.lowest_temperature}, ) # Parse the output using the associated parser @@ -786,10 +868,13 @@ async def generate_bot_message( log.info("Phase 3 :: Generating bot message ...") # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) # The last event should be the "StartInternalSystemAction" and the one before it the "BotIntent". event = get_last_bot_intent_event(events) + assert event assert event["type"] == "BotIntent" bot_intent = event["intent"] context_updates = {} @@ -833,12 +918,20 @@ async def generate_bot_message( if self.config.rails.dialog.single_call.enabled: event = get_last_user_intent_event(events) + if not event: + raise RuntimeError( + "No last user intent found to generate bot message" + ) if event["type"] == "UserIntent": bot_message_event = event["additional_info"]["bot_message_event"] # We only need to use the bot message if it corresponds to the # generate bot intent as well. last_bot_intent = get_last_bot_intent_event(events) + if not last_bot_intent: + raise RuntimeError( + "No last bot intent found to generate bot message" + ) if ( last_bot_intent["intent"] @@ -922,14 +1015,20 @@ async def generate_bot_message( else: prompt = context.get("user_message") - generation_options: GenerationOptions = generation_options_var.get() - llm_params = ( - generation_options and generation_options.llm_params - ) or {} + gen_options: Optional[ + GenerationOptions + ] = generation_options_var.get() + llm_params = (gen_options and gen_options.llm_params) or {} + custom_callback_handlers = ( + [streaming_handler] if streaming_handler else None + ) + + if not prompt: + raise RuntimeError("No prompt found to generate bot message") result = await llm_call( - llm, + generation_llm, prompt, - custom_callback_handlers=[streaming_handler], + custom_callback_handlers=custom_callback_handlers, llm_params=llm_params, ) @@ -954,7 +1053,7 @@ async def generate_bot_message( # NOTE: disabling bot message index when there are no user messages if self.config.user_messages and self.bot_message_index: results = await self.bot_message_index.search( - text=event["intent"], max_results=5 + text=event["intent"], max_results=5, threshold=None ) # We add these in reverse order so the most relevant is towards the end. 
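The hunks above repeatedly replace bare `assert event` checks with explicit `if not event: raise RuntimeError(...)` guards, so the Optional results of the event-lookup helpers are narrowed for the type-checker and the checks also survive optimized runs. A minimal sketch of that guard pattern; `get_last_event_of_type` and `require_last_event` are illustrative names, not the project's own utilities:

from typing import Any, Dict, List, Optional


def get_last_event_of_type(
    events: List[Dict[str, Any]], event_type: str
) -> Optional[Dict[str, Any]]:
    """Return the most recent event of the given type, or None if absent."""
    for event in reversed(events):
        if event.get("type") == event_type:
            return event
    return None


def require_last_event(
    events: List[Dict[str, Any]], event_type: str
) -> Dict[str, Any]:
    # An explicit raise is kept under `python -O` (which strips asserts) and gives
    # the type-checker a non-Optional value for the rest of the calling function.
    event = get_last_event_of_type(events, event_type)
    if event is None:
        raise RuntimeError(f"No '{event_type}' event found in the event stream.")
    return event

The same shape is applied to the `get_last_user_intent_event` and `get_last_bot_intent_event` call sites rewritten above.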
@@ -985,14 +1084,20 @@ async def generate_bot_message( # Initialize the LLMCallInfo object llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value)) - generation_options: GenerationOptions = generation_options_var.get() + generation_options: Optional[ + GenerationOptions + ] = generation_options_var.get() llm_params = ( generation_options and generation_options.llm_params ) or {} + custom_callback_handlers = ( + [streaming_handler] if streaming_handler else None + ) + result = await llm_call( - llm, + generation_llm, prompt, - custom_callback_handlers=[streaming_handler], + custom_callback_handlers=custom_callback_handlers, llm_params=llm_params, ) @@ -1060,7 +1165,9 @@ async def generate_value( :param llm: Custom llm model to generate_value """ # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) last_event = events[-1] assert last_event["type"] == "StartInternalSystemAction" @@ -1096,7 +1203,9 @@ async def generate_value( llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value)) result = await llm_call( - llm, prompt, llm_params={"temperature": self.config.lowest_temperature} + generation_llm, + prompt, + llm_params={"temperature": self.config.lowest_temperature}, ) # Parse the output using the associated parser @@ -1126,7 +1235,7 @@ async def generate_value( async def generate_intent_steps_message( self, events: List[dict], - llm: Optional[BaseLLM] = None, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, kb: Optional[KnowledgeBase] = None, ): """Generate all three main Guardrails phases with a single LLM call. @@ -1136,10 +1245,19 @@ async def generate_intent_steps_message( # The last event should be the "StartInternalSystemAction" and the one before it the "UtteranceUserActionFinished". event = get_last_user_utterance_event(events) - assert event["type"] == "UserMessage" - + if not event: + raise ValueError( + "No user message found in event stream. Unable to generate user intent." + ) + if event["type"] != "UserMessage": + raise ValueError( + f"Expected UserMessage event, but found {event['type']}. " + "Cannot generate user intent from this event type." + ) # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) streaming_handler = streaming_handler_var.get() @@ -1161,7 +1279,7 @@ async def generate_intent_steps_message( # Some of these intents might not have an associated flow and will be # skipped from the few-shot examples. 
intent_results = await self.user_message_index.search( - text=event["text"], max_results=10 + text=event["text"], max_results=10, threshold=None ) # We fill in the list of potential user intents @@ -1213,7 +1331,9 @@ async def generate_intent_steps_message( if self.bot_message_index: bot_messages_results = ( await self.bot_message_index.search( - text=bot_canonical_form, max_results=1 + text=bot_canonical_form, + max_results=1, + threshold=None, ) ) @@ -1273,7 +1393,7 @@ async def generate_intent_steps_message( await _streaming_handler.enable_buffering() asyncio.create_task( llm_call( - llm, + generation_llm, prompt, custom_callback_handlers=[_streaming_handler], stop=["\nuser ", "\nUser "], @@ -1299,12 +1419,15 @@ async def generate_intent_steps_message( LLMCallInfo(task=Task.GENERATE_INTENT_STEPS_MESSAGE.value) ) - generation_options: GenerationOptions = generation_options_var.get() + gen_options: Optional[GenerationOptions] = generation_options_var.get() + llm_params = (gen_options and gen_options.llm_params) or {} additional_params = { - **((generation_options and generation_options.llm_params) or {}), + **llm_params, "temperature": self.config.lowest_temperature, } - result = await llm_call(llm, prompt, llm_params=additional_params) + result = await llm_call( + generation_llm, prompt, llm_params=additional_params + ) # Parse the output using the associated parser result = self.llm_task_manager.parse_task_output( @@ -1318,9 +1441,11 @@ async def generate_intent_steps_message( # Get the next 2 non-empty lines, these should contain: # line 1 - user intent, line 2 - bot intent. # Afterwards we have the bot message. - next_three_lines = get_top_k_nonempty_lines(result, k=2) - user_intent = next_three_lines[0] if len(next_three_lines) > 0 else None - bot_intent = next_three_lines[1] if len(next_three_lines) > 1 else None + next_two_lines = get_top_k_nonempty_lines(result, k=2) + if not next_two_lines: + raise RuntimeError("Couldn't get last two lines to generate intent") + user_intent = next_two_lines[0] if len(next_two_lines) > 0 else None + bot_intent = next_two_lines[1] if len(next_two_lines) > 1 else None bot_message = None if bot_intent: pos = result.find(bot_intent) @@ -1384,9 +1509,9 @@ async def generate_intent_steps_message( llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value)) # We make this call with temperature 0 to have it as deterministic as possible. - generation_options: GenerationOptions = generation_options_var.get() - llm_params = (generation_options and generation_options.llm_params) or {} - result = await llm_call(llm, prompt, llm_params=llm_params) + gen_options: Optional[GenerationOptions] = generation_options_var.get() + llm_params = (gen_options and gen_options.llm_params) or {} + result = await llm_call(generation_llm, prompt, llm_params=llm_params) result = self.llm_task_manager.parse_task_output( Task.GENERAL, output=result diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index 03ea9ae38..c36899bb8 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -14,10 +14,12 @@ # limitations under the License. 
import re -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager +from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.base import Runnable from nemoguardrails.colang.v2_x.lang.colang_ast import Flow from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents @@ -56,9 +58,10 @@ def _infer_model_name(llm: BaseLanguageModel): if isinstance(val, str): return val - if hasattr(llm, "model_kwargs") and isinstance(llm.model_kwargs, dict): + model_kwargs = getattr(llm, "model_kwargs", None) + if model_kwargs and isinstance(model_kwargs, Dict): for attr in ["model", "model_name", "name"]: - val = llm.model_kwargs.get(attr) + val = model_kwargs.get(attr) if isinstance(val, str): return val @@ -67,12 +70,12 @@ def _infer_model_name(llm: BaseLanguageModel): async def llm_call( - llm: BaseLanguageModel, + llm: Optional[BaseLanguageModel], prompt: Union[str, List[dict]], model_name: Optional[str] = None, model_provider: Optional[str] = None, stop: Optional[List[str]] = None, - custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None, + custom_callback_handlers: Optional[Sequence[AsyncCallbackHandler]] = None, llm_params: Optional[dict] = None, ) -> str: """Calls the LLM with a prompt and returns the generated text. @@ -89,16 +92,23 @@ async def llm_call( Returns: The generated text response """ + if llm is None: + raise LLMCallException("No LLM provided to llm_call()") _setup_llm_call_info(llm, model_name, model_provider) all_callbacks = _prepare_callbacks(custom_callback_handlers) - if llm_params and llm is not None: - llm = llm.bind(**llm_params) + generation_llm: Union[BaseLanguageModel, Runnable] = ( + llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm + ) if isinstance(prompt, str): - response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop) + response = await _invoke_with_string_prompt( + generation_llm, prompt, all_callbacks + ) else: - response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop) + response = await _invoke_with_message_list( + generation_llm, prompt, all_callbacks + ) _store_tool_calls(response) _store_response_metadata(response) @@ -119,42 +129,40 @@ def _setup_llm_call_info( def _prepare_callbacks( - custom_callback_handlers: Optional[List[AsyncCallbackHandler]], + custom_callback_handlers: Optional[Sequence[AsyncCallbackHandler]], ) -> BaseCallbackManager: """Prepare callback manager with custom handlers if provided.""" if custom_callback_handlers and custom_callback_handlers != [None]: return BaseCallbackManager( - handlers=logging_callbacks.handlers + custom_callback_handlers, - inheritable_handlers=logging_callbacks.handlers + custom_callback_handlers, + handlers=logging_callbacks.handlers + list(custom_callback_handlers), + inheritable_handlers=logging_callbacks.handlers + + list(custom_callback_handlers), ) return logging_callbacks async def _invoke_with_string_prompt( - llm: BaseLanguageModel, + llm: Union[BaseLanguageModel, Runnable], prompt: str, callbacks: BaseCallbackManager, - stop: Optional[List[str]], ): """Invoke LLM with string prompt.""" try: - return await llm.ainvoke(prompt, config={"callbacks": callbacks, "stop": stop}) + return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks)) except Exception as e: raise LLMCallException(e) 
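The rewritten `llm_call` path above folds `stop` into `llm.bind(...)` together with the other per-call parameters and passes a typed `RunnableConfig` instead of a raw config dict. A small sketch of that call shape, assuming a standard LangChain language model; `invoke_with_params` is an illustrative name rather than a function from this module:

from typing import List, Optional, Union

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import Runnable


async def invoke_with_params(
    llm: BaseLanguageModel,
    prompt: str,
    callbacks: BaseCallbackManager,
    stop: Optional[List[str]] = None,
    llm_params: Optional[dict] = None,
):
    # bind() merges per-call parameters (including stop sequences) into the model
    # and returns a Runnable; callbacks travel via RunnableConfig rather than a dict.
    generation_llm: Union[BaseLanguageModel, Runnable] = (
        llm.bind(stop=stop, **llm_params) if llm_params else llm
    )
    return await generation_llm.ainvoke(
        prompt, config=RunnableConfig(callbacks=callbacks)
    )

One consequence of this change, visible in the updated tests further down, is that `bind` is now called with an explicit `stop` keyword, so the existing `assert_called_once_with(**llm_params)` expectations gain a `stop=None` argument.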
async def _invoke_with_message_list( - llm: BaseLanguageModel, + llm: Union[BaseLanguageModel, Runnable], prompt: List[dict], callbacks: BaseCallbackManager, - stop: Optional[List[str]], ): """Invoke LLM with message list after converting to LangChain format.""" messages = _convert_messages_to_langchain_format(prompt) + try: - return await llm.ainvoke( - messages, config={"callbacks": callbacks, "stop": stop} - ) + return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks)) except Exception as e: raise LLMCallException(e) diff --git a/nemoguardrails/actions/retrieve_relevant_chunks.py b/nemoguardrails/actions/retrieve_relevant_chunks.py index 46b178aed..16d9093f4 100644 --- a/nemoguardrails/actions/retrieve_relevant_chunks.py +++ b/nemoguardrails/actions/retrieve_relevant_chunks.py @@ -52,7 +52,7 @@ async def retrieve_relevant_chunks( ``` """ - user_message = context.get("last_user_message") + user_message: Optional[str] = context.get("last_user_message") if context else None context_updates = {} if user_message and kb: @@ -72,14 +72,18 @@ async def retrieve_relevant_chunks( else: # No KB is set up, we keep the existing relevant_chunks if we have them. if is_colang_2: - context_updates["relevant_chunks"] = context.get("relevant_chunks", "") + context_updates["relevant_chunks"] = ( + context.get("relevant_chunks", "") if context else None + ) if context_updates["relevant_chunks"]: context_updates["relevant_chunks"] += "\n" else: context_updates["relevant_chunks"] = ( - context.get("relevant_chunks", "") + "\n" + (context.get("relevant_chunks", "") + "\n") if context else None ) - context_updates["relevant_chunks_sep"] = context.get("relevant_chunks_sep", []) + context_updates["relevant_chunks_sep"] = ( + context.get("relevant_chunks_sep", []) if context else None + ) context_updates["retrieved_for"] = None return ActionResult( diff --git a/nemoguardrails/actions/summarize_document.py b/nemoguardrails/actions/summarize_document.py index 8ad1c6763..44937ba2c 100644 --- a/nemoguardrails/actions/summarize_document.py +++ b/nemoguardrails/actions/summarize_document.py @@ -15,7 +15,7 @@ from langchain.chains import AnalyzeDocumentChain from langchain.chains.summarize import load_summarize_chain -from langchain.llms import BaseLLM +from langchain_core.language_models.llms import BaseLLM from nemoguardrails.actions.actions import action diff --git a/nemoguardrails/actions/v2_x/generation.py b/nemoguardrails/actions/v2_x/generation.py index 5999ac81f..72e703a2c 100644 --- a/nemoguardrails/actions/v2_x/generation.py +++ b/nemoguardrails/actions/v2_x/generation.py @@ -19,8 +19,9 @@ import re import textwrap from ast import literal_eval -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union, cast +from langchain_core.language_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM from rich.text import Text @@ -32,12 +33,11 @@ get_first_bot_intent, get_first_nonempty_line, get_first_user_intent, - get_initial_actions, get_last_user_utterance_event_v2_x, llm_call, remove_action_intent_identifiers, ) -from nemoguardrails.colang.v2_x.lang.colang_ast import Flow +from nemoguardrails.colang.v2_x.lang.colang_ast import Flow, Spec, SpecOp from nemoguardrails.colang.v2_x.runtime.errors import LlmResponseError from nemoguardrails.colang.v2_x.runtime.flows import ActionEvent, InternalEvent from nemoguardrails.colang.v2_x.runtime.statemachine import ( @@ -60,6 +60,7 @@ from nemoguardrails.logging import verbose from 
nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.streaming import StreamingHandler from nemoguardrails.utils import console, new_uuid log = logging.getLogger(__name__) @@ -122,15 +123,23 @@ async def _init_flows_index(self) -> None: # The list of flows that have instructions, i.e. docstring at the beginning. instruction_flows = [] - for flow in self.config.flows: - colang_flow = flow.get("source_code") + # RailsConfig flow can be either Dict or Flow. Convert dicts to Flow for rest of the function + typed_flow: Flow = ( + Flow(**cast(Dict, flow)) if isinstance(flow, Dict) else flow + ) + colang_flow = typed_flow.source_code if colang_flow: - assert isinstance(flow, Flow) # Check if we need to exclude this flow. - if flow.file_info.get("exclude_from_llm") or ( - "meta" in flow.decorators - and flow.decorators["meta"].parameters.get("llm_exclude") + + has_llm_exclude_parameter: bool = any( + [ + "llm_exclude" in decorator.parameters + for decorator in typed_flow.decorators + ] + ) + if typed_flow.file_info.get("exclude_from_llm") or ( + "meta" in typed_flow.decorators and has_llm_exclude_parameter ): continue @@ -203,7 +212,8 @@ async def _collect_user_intent_and_examples( # We add all currently active user intents (heads on match statements) heads = find_all_active_event_matchers(state) for head in heads: - element = get_element_from_head(state, head) + el = get_element_from_head(state, head) + element = el if isinstance(el, SpecOp) else SpecOp(**cast(Dict, el)) flow_state = state.flow_states[head.flow_state_uid] event = get_event_from_element(state, flow_state, element) if ( @@ -222,10 +232,11 @@ async def _collect_user_intent_and_examples( and "_user_intent" in element_flow_state_instance[0].context ): if flow_config.elements[1]["_type"] == "doc_string_stmt": + # TODO! Need to make this type-safe but no idea what's going on examples += "user action: <" + ( - flow_config.elements[1]["elements"][0]["elements"][0][ - "elements" - ][0][3:-3] + flow_config.elements[1]["elements"][ # pyright: ignore + 0 + ]["elements"][0]["elements"][0][3:-3] + ">\n" ) examples += f"user intent: {flow_id}\n\n" @@ -250,7 +261,7 @@ async def get_last_user_message( return event["final_transcript"] @action(name="GenerateUserIntentAction", is_system_action=True, execute_async=True) - async def generate_user_intent( + async def generate_user_intent( # pyright: ignore (TODO - Signature completely different to base class) self, state: State, events: List[dict], @@ -261,7 +272,9 @@ async def generate_user_intent( """Generate the canonical form for what the user said i.e. user intent.""" # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) log.info("Phase 1 :: Generating user intent") ( @@ -294,7 +307,7 @@ async def generate_user_intent( # We make this call with lowest temperature to have it as deterministic as possible. result = await llm_call( - llm, + generation_llm, prompt, stop=stop, llm_params={"temperature": self.config.lowest_temperature}, @@ -342,8 +355,9 @@ async def generate_user_intent_and_bot_action( """Generate the canonical form for what the user said i.e. 
user intent and a suitable bot action.""" # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm - + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) log.info("Phase 1 :: Generating user intent and bot action") ( @@ -376,7 +390,7 @@ async def generate_user_intent_and_bot_action( # We make this call with lowest temperature to have it as deterministic as possible. result = await llm_call( - llm, + generation_llm, prompt, stop=stop, llm_params={"temperature": self.config.lowest_temperature}, @@ -429,7 +443,14 @@ async def passthrough_llm_action( events: List[dict], llm: Optional[BaseLLM] = None, ): + if not llm: + raise RuntimeError("No LLM provided to passthrough LLM Action") + event = get_last_user_utterance_event_v2_x(events) + if not event: + raise RuntimeError( + "Passthrough LLM Action couldn't find last user utterance" + ) # We check if we have a raw request. If the guardrails API is using # the `generate_events` API, this will not be set. @@ -455,20 +476,21 @@ async def passthrough_llm_action( # Initialize the LLMCallInfo object llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value)) - generation_options: GenerationOptions = generation_options_var.get() + generation_options: Optional[GenerationOptions] = generation_options_var.get() + + streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get() + custom_callback_handlers = [streaming_handler] if streaming_handler else None generation_llm_params = generation_options and generation_options.llm_params text = await llm_call( llm, user_message, - custom_callback_handlers=[streaming_handler_var.get()], + custom_callback_handlers=custom_callback_handlers, llm_params=generation_llm_params, ) text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text) - text = result.text - return text @action(name="CheckValidFlowExistsAction", is_system_action=True) @@ -514,12 +536,13 @@ async def generate_flow_from_instructions( raise RuntimeError("No instruction flows index has been created.") # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm - + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) log.info("Generating flow for instructions: %s", instructions) results = await self.instruction_flows_index.search( - text=instructions, max_results=5 + text=instructions, max_results=5, threshold=None ) examples = "" @@ -546,7 +569,9 @@ async def generate_flow_from_instructions( # We make this call with temperature 0 to have it as deterministic as possible. 
result = await llm_call( - llm, prompt, llm_params={"temperature": self.config.lowest_temperature} + generation_llm, + prompt, + llm_params={"temperature": self.config.lowest_temperature}, ) result = self.llm_task_manager.parse_task_output( @@ -593,12 +618,16 @@ async def generate_flow_from_name( raise RuntimeError("No flows index has been created.") # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm - + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) log.info("Generating flow for name: {name}") + if not self.instruction_flows_index: + raise RuntimeError("No instruction flows index has been created.") + results = await self.instruction_flows_index.search( - text=f"flow {name}", max_results=5 + text=f"flow {name}", max_results=5, threshold=None ) examples = "" @@ -621,7 +650,7 @@ async def generate_flow_from_name( # We make this call with temperature 0 to have it as deterministic as possible. result = await llm_call( - llm, + generation_llm, prompt, stop=stop, llm_params={"temperature": self.config.lowest_temperature}, @@ -659,7 +688,9 @@ async def generate_flow_continuation( raise RuntimeError("No instruction flows index has been created.") # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) log.info("Generating flow continuation.") @@ -668,7 +699,11 @@ async def generate_flow_continuation( # We use the last line from the history to search for relevant flows search_text = colang_history.split("\n")[-1] - results = await self.flows_index.search(text=search_text, max_results=10) + if self.flows_index is None: + raise RuntimeError("No flows index has been created.") + results = await self.flows_index.search( + text=search_text, max_results=10, threshold=None + ) examples = "" for result in reversed(results): @@ -690,7 +725,9 @@ async def generate_flow_continuation( ) # We make this call with temperature 0 to have it as deterministic as possible. - result = await llm_call(llm, prompt, llm_params={"temperature": temperature}) + result = await llm_call( + generation_llm, prompt, llm_params={"temperature": temperature} + ) # TODO: Currently, we only support generating a bot action as continuation. This could be generalized # Colang statements. @@ -767,7 +804,7 @@ async def create_flow( } @action(name="GenerateValueAction", is_system_action=True, execute_async=True) - async def generate_value( + async def generate_value( # pyright: ignore (TODO - different arguments to base-class) self, state: State, instructions: str, @@ -783,22 +820,26 @@ async def generate_value( :param llm: Custom llm model to generate_value """ # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) # We search for the most relevant flows. examples = "" if self.flows_index: + results = None if var_name: results = await self.flows_index.search( - text=f"${var_name} = ", max_results=5 + text=f"${var_name} = ", max_results=5, threshold=None ) # We add these in reverse order so the most relevant is towards the end. - for result in reversed(results): - # If the flow includes "GenerateValueAction", we ignore it as we don't want the LLM - # to learn to predict it. 
- if "GenerateValueAction" not in result.text: - examples += f"{result.text}\n\n" + if results: + for result in reversed(results): + # If the flow includes "GenerateValueAction", we ignore it as we don't want the LLM + # to learn to predict it. + if "GenerateValueAction" not in result.text: + examples += f"{result.text}\n\n" llm_call_info_var.set( LLMCallInfo(task=Task.GENERATE_VALUE_FROM_INSTRUCTION.value) @@ -819,7 +860,9 @@ async def generate_value( Task.GENERATE_USER_INTENT_FROM_USER_ACTION ) - result = await llm_call(llm, prompt, stop=stop, llm_params={"temperature": 0.1}) + result = await llm_call( + generation_llm, prompt, stop=stop, llm_params={"temperature": 0.1} + ) # Parse the output using the associated parser result = self.llm_task_manager.parse_task_output( @@ -862,11 +905,19 @@ async def generate_flow( ) -> dict: """Generate the body for a flow.""" # Use action specific llm if registered else fallback to main llm - llm = llm or self.llm + generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = ( + llm if llm else self.llm + ) triggering_flow_id = flow_id + if not triggering_flow_id: + raise RuntimeError( + "No flow_id provided to generate flow." + ) # TODO! Should flow_id be mandatory? flow_config = state.flow_configs[triggering_flow_id] + if not flow_config.source_code: + raise RuntimeError(f"No source_code in flow_config {flow_config}") docstrings = re.findall(r'"""(.*?)"""', flow_config.source_code, re.DOTALL) if len(docstrings) > 0: @@ -888,6 +939,10 @@ async def generate_flow( for flow_config in state.flow_configs.values(): if flow_config.decorators.get("meta", {}).get("tool") is True: # We get rid of the first line, which is the decorator + + if not flow_config.source_code: + raise Exception(f"No source_code in flow_config {flow_config}") + body = flow_config.source_code.split("\n", maxsplit=1)[1] # We only need the part up to the docstring @@ -928,7 +983,7 @@ async def generate_flow( ) result = await llm_call( - llm, + generation_llm, prompt, stop=stop, llm_params={"temperature": self.config.lowest_temperature}, diff --git a/nemoguardrails/actions/validation/base.py b/nemoguardrails/actions/validation/base.py index 572ea5528..a92fd6673 100644 --- a/nemoguardrails/actions/validation/base.py +++ b/nemoguardrails/actions/validation/base.py @@ -14,7 +14,7 @@ # limitations under the License. import json import re -from typing import List +from typing import List, Sequence from urllib.parse import quote from nemoguardrails.actions.validation.filter_secrets import contains_secrets @@ -22,7 +22,7 @@ MAX_LEN = 50 -def validate_input(attribute: str, validators: List[str] = (), **validation_args): +def validate_input(attribute: str, validators: Sequence[str] = (), **validation_args): """A generic decorator that can be used by any action (class method or function) for input validation. Supported validation choices are: length and quote. diff --git a/nemoguardrails/actions/validation/filter_secrets.py b/nemoguardrails/actions/validation/filter_secrets.py index ff6132332..8b4cb10c3 100644 --- a/nemoguardrails/actions/validation/filter_secrets.py +++ b/nemoguardrails/actions/validation/filter_secrets.py @@ -22,7 +22,7 @@ def contains_secrets(resp): ArtifactoryDetector : False """ try: - import detect_secrets + import detect_secrets # type: ignore (Assume user installs detect_secrets with instructions below) except ModuleNotFoundError: raise ValueError( "Could not import detect_secrets. 
Please install using `pip install detect-secrets`" diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py index 2e7d34b82..b9f04b5f9 100644 --- a/nemoguardrails/context.py +++ b/nemoguardrails/context.py @@ -17,7 +17,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.streaming import StreamingHandler +streaming_handler_var: contextvars.ContextVar[ + Optional[StreamingHandler] +] = contextvars.ContextVar("streaming_handler", default=None) if TYPE_CHECKING: from nemoguardrails.logging.explain import ExplainInfo from nemoguardrails.logging.stats import LLMStats @@ -40,7 +45,7 @@ # All the generation options applicable to the current context. generation_options_var: contextvars.ContextVar[ - Optional["GenerationOptions"] + Optional[GenerationOptions] ] = contextvars.ContextVar("generation_options", default=None) # The stats about the LLM calls. diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 6c5073a78..749ecfd32 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -492,7 +492,7 @@ class OutputRails(BaseModel): description="Configuration for streaming output rails.", ) - apply_to_reasoning_traces: Optional[bool] = Field( + apply_to_reasoning_traces: bool = Field( default=False, description=( "If True, output rails will apply guardrails to both reasoning traces and output response. " diff --git a/pyproject.toml b/pyproject.toml index c80c38e50..cd96bdf32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,9 @@ pyright = "^1.1.405" # Directories in which to run Pyright type-checking [tool.pyright] -include = ["nemoguardrails/rails/**", "tests/test_callbacks.py"] +include = ["nemoguardrails/rails/**", + "nemoguardrails/actions/**", + "tests/test_callbacks.py"] [tool.poetry.group.docs] optional = true diff --git a/tests/test_action_dispatcher.py b/tests/test_action_dispatcher.py index 59fa9cef1..9e21c986f 100644 --- a/tests/test_action_dispatcher.py +++ b/tests/test_action_dispatcher.py @@ -131,3 +131,50 @@ def test_load_actions_from_module_relative_path_exception(monkeypatch): assert "invalid syntax" in error_message assert "exception:" in error_message + + +@pytest.mark.asyncio +async def test_execute_missing_action_raises(): + """Create an action with name but no function to call, check it raises an exception""" + + missing_action_name = "missing_test_action" + dispatcher = ActionDispatcher(load_all_actions=False) + dispatcher.register_action(None, name=missing_action_name) + + with pytest.raises(Exception, match="is not registered."): + _ = await dispatcher.execute_action(missing_action_name, params={}) + + +@pytest.mark.asyncio +async def test_execute_action_not_callable_raises(caplog): + """Register a function with a "run" attribute that isn't callable""" + + action_name = "uncallable_test_action" + dispatcher = ActionDispatcher(load_all_actions=False) + dispatcher.register_action({"run": "not callable"}, name=action_name) + + # No Exception is raised, it gets caught and logged out as an error instead + result = await dispatcher.execute_action(action_name, params={}) + assert result == (None, "failed") + last_log = caplog.records[-1] + assert last_log.levelname == "ERROR" + assert last_log.message == f"No 'run' method defined for action '{action_name}'." 
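The `context.py` hunk above gives `streaming_handler_var` an explicit `ContextVar[Optional[StreamingHandler]]` type, which is what lets callers such as `passthrough_llm_action` narrow the handler before building a callback list instead of passing `[None]`. A generic sketch of that typed-ContextVar pattern, with illustrative names that are not part of the project:

import contextvars
from typing import List, Optional


class Handler:
    """Stand-in for a streaming or callback handler type."""


handler_var: contextvars.ContextVar[Optional[Handler]] = contextvars.ContextVar(
    "handler", default=None
)


def current_callbacks() -> Optional[List[Handler]]:
    # get() is typed as Optional[Handler]; the truthiness check narrows it so the
    # caller never receives a list containing None.
    handler = handler_var.get()
    return [handler] if handler else None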
+ + +@pytest.mark.asyncio +async def test_execute_action_with_signature(): + """Register a function with a "run" attribute that **is** callable""" + + action_name = "callable_test_action" + action_return = "The callable test action was just called!" + + class test_class: + def run(self): + return action_return + + dispatcher = ActionDispatcher(load_all_actions=False) + dispatcher.register_action(test_class, name=action_name) + + # No Exception is raised, it gets caught and logged out as an error instead + result = await dispatcher.execute_action(action_name, params={}) + assert result == (action_return, "success") diff --git a/tests/test_general_instructions.py b/tests/test_general_instructions.py index f4395e3eb..637553633 100644 --- a/tests/test_general_instructions.py +++ b/tests/test_general_instructions.py @@ -13,7 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import MagicMock + +import pytest + from nemoguardrails import RailsConfig +from nemoguardrails.actions.llm.generation import LLMGenerationActions +from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.rails.llm.config import Instruction, Model, RailsConfig from tests.utils import TestChat @@ -44,3 +51,144 @@ def test_general_instructions_get_included_when_no_canonical_forms_are_defined() assert ( "This is a conversation between a user and a bot." in info.llm_calls[0].prompt ) + + +def test_get_general_instructions_none(): + """Check we get None when RailsConfig.instructions is None.""" + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + instructions=None, + ) + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + instructions = actions._get_general_instructions() + assert instructions is None + + +def test_get_general_instructions_empty_list(): + """Check an empty list of instructions returns an empty string""" + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + ) + config.instructions = [] + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + instructions = actions._get_general_instructions() + assert instructions == "" + + +def test_get_general_instructions_list(): + """Check a list of instructions where the second one is general""" + + first_general_instruction = "Don't answer with any inappropriate content." 
+ instructions = [ + Instruction(type="specific", content="You're a helpful bot "), + Instruction(type="general", content=first_general_instruction), + ] + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + instructions=instructions, + ) + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + instructions = actions._get_general_instructions() + assert instructions == first_general_instruction + + +def test_get_sample_conversation_two_turns(): + """Check if the RailsConfig sample_conversation is None we get None back""" + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + sample_conversation=None, + ) + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + conversation = actions._get_sample_conversation_two_turns() + assert conversation is None + + +@pytest.mark.asyncio +async def test_search_flows_index_is_none(): + """Check if we try and search the flows index when None we get None back""" + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + sample_conversation=None, + ) + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + with pytest.raises(RuntimeError, match="No flows index found to search"): + _ = await actions._search_flows_index(text="default action", max_results=1) + + +@pytest.mark.asyncio +async def test_generate_next_step_empty_event_list(): + """Check if we try and search the flows index when None we get None back""" + + config = RailsConfig( + models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")], + colang_version="1.0", + sample_conversation=None, + ) + + actions = LLMGenerationActions( + config, + llm=None, + llm_task_manager=MagicMock(spec=LLMTaskManager), + get_embedding_search_provider_instance=MagicMock(), + ) + + with pytest.raises( + RuntimeError, match="No last user intent found from which to generate next step" + ): + _ = await actions.generate_next_step(events=[]) + + +# +# @pytest.mark.asyncio +# async def test_generate_next_step_last_user_intent_is_none(): +# +# # +# events = [{"type": "UserIntent", "content": "You're a helpful bot "} +# {"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}] +# +# actions._generate_next_step = MagicMock(return_value="default action") diff --git a/tests/test_tool_calling_utils.py b/tests/test_tool_calling_utils.py index d8c96d574..0381086e1 100644 --- a/tests/test_tool_calling_utils.py +++ b/tests/test_tool_calling_utils.py @@ -255,7 +255,7 @@ async def test_llm_call_with_llm_params(): result = await llm_call(mock_llm, "Test prompt", llm_params=llm_params) assert result == "LLM response with params" - mock_llm.bind.assert_called_once_with(**llm_params) + mock_llm.bind.assert_called_once_with(stop=None, **llm_params) mock_bound_llm.ainvoke.assert_called_once() @@ -304,7 +304,7 @@ async def test_llm_call_with_llm_params_temperature_max_tokens(): result = await llm_call(mock_llm, "Test prompt", llm_params=llm_params) assert result == "Response with temp and tokens" - mock_llm.bind.assert_called_once_with(temperature=0.8, max_tokens=50) + 
mock_llm.bind.assert_called_once_with(stop=None, temperature=0.8, max_tokens=50) mock_bound_llm.ainvoke.assert_called_once() diff --git a/tests/v2_x/test_passthroug_mode.py b/tests/v2_x/test_passthroug_mode.py index b4e0ff3df..0466eb997 100644 --- a/tests/v2_x/test_passthroug_mode.py +++ b/tests/v2_x/test_passthroug_mode.py @@ -81,6 +81,9 @@ def test_passthrough_llm_action_not_invoked_via_logs(self): self.assertIn("content", response) self.assertIsInstance(response["content"], str) + @unittest.skip( + reason="Github issue https://github.com/NVIDIA/NeMo-Guardrails/issues/1378" + ) def test_passthrough_llm_action_invoked_via_logs(self): chat = TestChat( config,