|
| 1 | +import openai |
1 | 2 | import os
|
2 | 3 | import unittest
|
3 |
| -from unittest.mock import patch |
4 | 4 |
|
5 |
| -# from unittest import mock |
6 | 5 | from gradio_client import Client
|
7 |
| -from app import build_chat_context, SystemMessage, HumanMessage, AIMessage |
| 6 | +from unittest.mock import patch, MagicMock |
| 7 | +from langchain.schema import HumanMessage, AIMessage, SystemMessage |
| 8 | +from app import build_chat_context, inference, PossibleSystemPromptException, BACKEND_INITIALISED |
8 | 9 |
|
9 | 10 | url = os.environ.get("GRADIO_URL", "http://localhost:7860")
|
10 | 11 | client = Client(url)
|
@@ -58,5 +59,42 @@ def test_chat_context_human_prompt(self, mock_settings):
|
58 | 59 | self.assertIsInstance(context[2], HumanMessage)
|
59 | 60 | self.assertEqual(context[2].content, latest_message)
|
60 | 61 |
|
| 62 | + # inference function tests |
| 63 | + @patch("app.settings") |
| 64 | + @patch("app.llm") |
| 65 | + @patch("app.log") |
| 66 | + def test_inference_success(self, mock_logger, mock_llm, mock_settings): |
| 67 | + mock_llm.stream.return_value = [MagicMock(content="response_chunk")] |
| 68 | + |
| 69 | + mock_settings.model_instruction = "You are a very helpful assistant." |
| 70 | + latest_message = "Why don't we drink horse milk?" |
| 71 | + history = [ |
| 72 | + {"role": "user", 'metadata': None, "content": "Hi there!", 'options': None}, |
| 73 | + {"role": "assistant", 'metadata': None, "content": "Hi! How can I help you?", 'options': None}, |
| 74 | + ] |
| 75 | + |
| 76 | + responses = list(inference(latest_message, history)) |
| 77 | + |
| 78 | + self.assertEqual(responses, ["response_chunk"]) |
| 79 | + mock_logger.debug.assert_any_call("Inference request received with history: %s", history) |
| 80 | + |
| 81 | + @patch("app.llm") |
| 82 | + @patch("app.build_chat_context") |
| 83 | + def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm): |
| 84 | + mock_build_chat_context.return_value = ["mock_context"] |
| 85 | + mock_llm.stream.return_value = [ |
| 86 | + MagicMock(content="<think>"), |
| 87 | + MagicMock(content="processing"), |
| 88 | + MagicMock(content="</think>"), |
| 89 | + MagicMock(content="final response"), |
| 90 | + ] |
| 91 | + latest_message = "Hello" |
| 92 | + history = [] |
| 93 | + |
| 94 | + responses = list(inference(latest_message, history)) |
| 95 | + |
| 96 | + self.assertEqual(responses, ["Thinking...", "Thinking...", "", "final response"]) |
| 97 | + |
| 98 | + |
61 | 99 | if __name__ == "__main__":
|
62 | 100 | unittest.main()
|
0 commit comments