diff --git a/nemoguardrails/rails/llm/llm_flows.co b/nemoguardrails/rails/llm/llm_flows.co
index c93a5e3bb..1753a3727 100644
--- a/nemoguardrails/rails/llm/llm_flows.co
+++ b/nemoguardrails/rails/llm/llm_flows.co
@@ -28,11 +28,10 @@ define flow run dialog rails
 
   # If the dialog_rails are disabled
   if $generation_options and $generation_options.rails.dialog == False
-    # If the output rails are also disabled, we just return user message.
-    if $generation_options.rails.output == False
+    # If output rails are disabled or there's no bot message to check, skip output rails.
+    if $generation_options.rails.output == False or $bot_message is None
       create event StartUtteranceBotAction(script=$user_message)
     else
-      # we take the $bot_message from context.
       create event BotMessage(text=$bot_message)
   else
     # If not, we continue the usual process
diff --git a/tests/test_generation_options.py b/tests/test_generation_options.py
index 0b01f63fc..fd990ebff 100644
--- a/tests/test_generation_options.py
+++ b/tests/test_generation_options.py
@@ -344,3 +344,74 @@ def test_generation_log_print_summary(capsys):
         capture_lines[8]
         == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens."
     )
+
+
+@pytest.mark.parametrize(
+    "input_opt,output_opt,dialog_opt,expect_input,expect_output",
+    [
+        (True, True, True, True, True),
+        (True, True, False, True, False),
+        (True, False, True, True, False),
+        (True, False, False, True, False),
+        (False, True, True, False, True),
+        (False, True, False, False, False),
+        (False, False, True, False, False),
+        (False, False, False, False, False),
+    ],
+)
+@pytest.mark.asyncio
+async def test_rails_options_combinations(input_opt, output_opt, dialog_opt, expect_input, expect_output):
+    """
+    Test all combinations of input/output/dialog options.
+    When dialog=False and no bot_message is provided, output rails should skip.
+    """
+    config = RailsConfig.from_content(
+        colang_content="""
+        define user express greeting
+          "hi"
+
+        define flow
+          user express greeting
+          bot express greeting
+
+        define subflow dummy input rail
+          if "block" in $user_message
+            bot refuse to respond
+            stop
+
+        define subflow dummy output rail
+          if "block" in $bot_message
+            bot refuse to respond
+            stop
+        """,
+        yaml_content="""
+        rails:
+          input:
+            flows:
+              - dummy input rail
+          output:
+            flows:
+              - dummy output rail
+        """,
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[" express greeting", ' "Hello!"'] if dialog_opt else [],
+    )
+
+    res: GenerationResponse = await chat.app.generate_async(
+        "Hello!",
+        options={
+            "rails": {"input": input_opt, "output": output_opt, "dialog": dialog_opt},
+            "log": {"activated_rails": True},
+        },
+    )
+
+    activated_rails = res.log.activated_rails if res.log else []
+    rail_names = [r.name for r in activated_rails]
+
+    input_rails_ran = any("input" in name.lower() for name in rail_names)
+    output_rails_ran = any("output" in name.lower() for name in rail_names)
+
+    assert input_rails_ran == expect_input, f"Input rails: expected {expect_input}, got {rail_names}"
+    assert output_rails_ran == expect_output, f"Output rails: expected {expect_output}, got {rail_names}"
diff --git a/tests/test_parallel_rails_exceptions.py b/tests/test_parallel_rails_exceptions.py
index 80cf43f00..3d94d8036 100644
--- a/tests/test_parallel_rails_exceptions.py
+++ b/tests/test_parallel_rails_exceptions.py
@@ -375,8 +375,11 @@ async def test_output_rails_only_parallel_with_exceptions():
         },
     }
 
-    chat >> "Hello"
-    result = await chat.app.generate_async(messages=chat.history, options=options_output_only)
+    messages = [
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "This response contains harmful content"},
+    ]
+    result = await chat.app.generate_async(messages=messages, options=options_output_only)
 
     input_rails = [r for r in result.log.activated_rails if r.type == "input"]
     output_rails = [r for r in result.log.activated_rails if r.type == "output"]