Skip to content

Commit ffe58ac

Browse files
committed
Add first inference function tests
1 parent a3dc256 commit ffe58ac

File tree

1 file changed

+41
-3
lines changed

web-apps/chat/test.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import openai
12
import os
23
import unittest
3-
from unittest.mock import patch
44

5-
# from unittest import mock
65
from gradio_client import Client
7-
from app import build_chat_context, SystemMessage, HumanMessage, AIMessage
6+
from unittest.mock import patch, MagicMock
7+
from langchain.schema import HumanMessage, AIMessage, SystemMessage
8+
from app import build_chat_context, inference, PossibleSystemPromptException, BACKEND_INITIALISED
89

910
url = os.environ.get("GRADIO_URL", "http://localhost:7860")
1011
client = Client(url)
@@ -58,5 +59,42 @@ def test_chat_context_human_prompt(self, mock_settings):
5859
self.assertIsInstance(context[2], HumanMessage)
5960
self.assertEqual(context[2].content, latest_message)
6061

# inference function tests
@patch("app.settings")
@patch("app.llm")
@patch("app.log")
def test_inference_success(self, logger_mock, llm_mock, settings_mock):
    """inference() yields the streamed chunk and logs the incoming history.

    The LLM backend is mocked to stream a single chunk; the test checks
    both the generator output and the debug log call.
    """
    # One-chunk stream from the mocked backend.
    llm_mock.stream.return_value = [MagicMock(content="response_chunk")]
    settings_mock.model_instruction = "You are a very helpful assistant."

    prompt = "Why don't we drink horse milk?"
    # Two prior turns in the Gradio "messages" history format.
    chat_history = [
        {"role": role, "metadata": None, "content": text, "options": None}
        for role, text in (
            ("user", "Hi there!"),
            ("assistant", "Hi! How can I help you?"),
        )
    ]

    streamed = list(inference(prompt, chat_history))

    self.assertEqual(streamed, ["response_chunk"])
    logger_mock.debug.assert_any_call("Inference request received with history: %s", chat_history)
80+
81+
@patch("app.llm")
@patch("app.build_chat_context")
def test_inference_thinking_tags(self, context_mock, llm_mock):
    """Chunks inside <think>...</think> are masked as "Thinking...".

    While the model is "thinking", inference() should emit the
    placeholder text, emit an empty string when the closing tag
    arrives, and then pass real content through unchanged.
    """
    context_mock.return_value = ["mock_context"]
    # Stream: open tag, hidden reasoning, close tag, visible answer.
    chunk_texts = ("<think>", "processing", "</think>", "final response")
    llm_mock.stream.return_value = [MagicMock(content=text) for text in chunk_texts]

    collected = list(inference("Hello", []))

    self.assertEqual(collected, ["Thinking...", "Thinking...", "", "final response"])
97+
98+
6199
# Allow running this test module directly (python test.py) as well as
# via a test runner; unittest.main() discovers the TestCase classes above.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)