Commit 343b759

andompesta and Sandro Cavallari authored
fix: Prevent explain_info overwrite during stream_async (#1194)
* add `_ensure_explain_info` function to keep the `explain_info_var` context consistent between `stream_async` and `generate_async`
* apply pre-commit changes

Co-authored-by: Sandro Cavallari <scavallari@gmail.com>
1 parent f400e6f commit 343b759
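For context, the user-facing behaviour this fix protects looks roughly like the sketch below: stream a response and then read the per-request explanation afterwards. The config path and prompt are illustrative placeholders; `LLMRails`, `RailsConfig`, `stream_async`, and `explain()` are the public APIs exercised by the tests in this commit.

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main() -> None:
    # "./config" is a placeholder path to any valid guardrails configuration.
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    # Stream the response token by token.
    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Hi!"}]
    ):
        print(chunk, end="", flush=True)

    # With this fix, the explanation info seeded by stream_async is the same
    # object used downstream, so per-request LLM call stats stay intact.
    info = rails.explain()
    print(f"\nLLM calls made: {len(info.llm_calls)}")


asyncio.run(main())
```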

File tree

2 files changed: +139 -19 lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 22 additions & 19 deletions
@@ -51,7 +51,6 @@
     generation_options_var,
     llm_stats_var,
     raw_llm_request,
-    reasoning_trace_var,
     streaming_handler_var,
 )
 from nemoguardrails.embeddings.index import EmbeddingsIndex
@@ -576,6 +575,20 @@ def _get_events_for_messages(self, messages: List[dict], state: Any):
 
         return events
 
+    @staticmethod
+    def _ensure_explain_info() -> ExplainInfo:
+        """Ensure that the ExplainInfo variable is present in the current context
+
+        Returns:
+            A ExplainInfo class containing the llm calls' statistics
+        """
+        explain_info = explain_info_var.get()
+        if explain_info is None:
+            explain_info = ExplainInfo()
+            explain_info_var.set(explain_info)
+
+        return explain_info
+
     async def generate_async(
         self,
         prompt: Optional[str] = None,
@@ -634,14 +647,7 @@ async def generate_async(
         # Initialize the object with additional explanation information.
         # We allow this to also be set externally. This is useful when multiple parallel
         # requests are made.
-        explain_info = explain_info_var.get()
-        if explain_info is None:
-            explain_info = ExplainInfo()
-            explain_info_var.set(explain_info)
-
-        # We also keep a general reference to this object
-        self.explain_info = explain_info
-        self.explain_info = explain_info
+        self.explain_info = self._ensure_explain_info()
 
         if prompt is not None:
             # Currently, we transform the prompt request into a single turn conversation
@@ -805,9 +811,11 @@ async def generate_async(
 
         # If logging is enabled, we log the conversation
         # TODO: add support for logging flag
-        explain_info.colang_history = get_colang_history(events)
+        self.explain_info.colang_history = get_colang_history(events)
         if self.verbose:
-            log.info(f"Conversation history so far: \n{explain_info.colang_history}")
+            log.info(
+                f"Conversation history so far: \n{self.explain_info.colang_history}"
+            )
 
         total_time = time.time() - t0
         log.info(
@@ -960,6 +968,8 @@ def stream_async(
         include_generation_metadata: Optional[bool] = False,
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""
+        self.explain_info = self._ensure_explain_info()
+
         streaming_handler = StreamingHandler(
             include_generation_metadata=include_generation_metadata
         )
@@ -1278,13 +1288,6 @@ def _prepare_params(
                 **action_params,
             }
 
-        def _update_explain_info():
-            explain_info = explain_info_var.get()
-            if explain_info is None:
-                explain_info = ExplainInfo()
-                explain_info_var.set(explain_info)
-            self.explain_info = explain_info
-
         output_rails_streaming_config = self.config.rails.output.streaming
         buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
         output_rails_flows_id = self.config.rails.output.flows
@@ -1329,7 +1332,7 @@ def _update_explain_info():
                     action_name, params
                 )
                 # Include explain info (whatever _update_explain_info does)
-                _update_explain_info()
+                self.explain_info = self._ensure_explain_info()
 
                 # Retrieve the action function from the dispatcher
                 action_func = self.runtime.action_dispatcher.get_action(action_name)
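The helper introduced above is the usual `contextvars` get-or-create pattern. A self-contained sketch of the same idea follows; all names here are invented for illustration and none of them come from nemoguardrails.

```python
import contextvars
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DemoExplainInfo:
    """Stand-in for ExplainInfo: accumulates per-request LLM call records."""

    llm_calls: List[str] = field(default_factory=list)


# One value per execution context, shared by every entry point of a request.
demo_explain_info_var: contextvars.ContextVar[Optional[DemoExplainInfo]] = (
    contextvars.ContextVar("demo_explain_info", default=None)
)


def ensure_demo_explain_info() -> DemoExplainInfo:
    """Return the current context's info object, creating it only if missing.

    Reusing the existing value is the point of the fix: a second entry point
    (e.g. stream_async calling into generate_async) no longer replaces the
    object that is already collecting statistics.
    """
    info = demo_explain_info_var.get()
    if info is None:
        info = DemoExplainInfo()
        demo_explain_info_var.set(info)
    return info


first = ensure_demo_explain_info()
first.llm_calls.append("call-1")
second = ensure_demo_explain_info()
assert first is second  # same object, nothing overwritten
print(second.llm_calls)  # ['call-1']
```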
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+
+import pytest
+
+from nemoguardrails import RailsConfig
+from tests.utils import TestChat
+
+
+@pytest.mark.asyncio
+async def test_1():
+    config = RailsConfig.from_content(
+        """
+        define user express greeting
+          "hello"
+
+        define flow
+          user express greeting
+          bot express greeting
+        """
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "express greeting",
+            "Hello! I'm doing great, thank you. How can I assist you today?",
+        ],
+    )
+
+    new_messages = await chat.app.generate_async(
+        messages=[{"role": "user", "content": "hi, how are you"}]
+    )
+
+    assert new_messages == {
+        "content": "Hello! I'm doing great, thank you. How can I assist you today?",
+        "role": "assistant",
+    }, "message content do not match"
+
+    # note that 2 llm call are expected as we matched the bot intent
+    assert (
+        len(chat.app.explain().llm_calls) == 2
+    ), "number of llm call not as expected. Expected 2, found {}".format(
+        len(chat.app.explain().llm_calls)
+    )
+
+
+@pytest.mark.asyncio
+async def test_2():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    # run the real self check output rails
+                    "flows": {"self check output"},
+                    "streaming": {
+                        "enabled": True,
+                        "chunk_size": 4,
+                        "context_size": 2,
+                        "stream_first": False,
+                    },
+                }
+            },
+            "streaming": False,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+        define user express greeting
+          "hi"
+
+        define flow
+          user express greeting
+          bot tell joke
+        """,
+    )
+
+    llm_completions = [
+        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
+        ' "This is a joke that should be blocked."',
+        # add as many `no`` as chunks you want the output stream to check
+        "No",
+        "No",
+        "Yes",
+    ]
+
+    chat = TestChat(
+        config,
+        llm_completions=llm_completions,
+        streaming=True,
+    )
+    chunks = []
+    async for chunk in chat.app.stream_async(
+        messages=[{"role": "user", "content": "Hi!"}],
+    ):
+        chunks.append(chunk)
+
+    # note that 6 llm call are expected as we matched the bot intent
+    assert (
+        len(chat.app.explain().llm_calls) == 5
+    ), "number of llm call not as expected. Expected 5, found {}".format(
+        len(chat.app.explain().llm_calls)
+    )
+
+    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
