diff --git a/notebooks/agents/langchain_agent_simple_demo.ipynb b/notebooks/agents/langchain_agent_simple_demo.ipynb index cd804d468..c3658a07e 100644 --- a/notebooks/agents/langchain_agent_simple_demo.ipynb +++ b/notebooks/agents/langchain_agent_simple_demo.ipynb @@ -57,12 +57,10 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import List, Optional, Dict, Any\n", + "from typing import Optional, Dict, Any\n", "from langchain.tools import tool\n", - "from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage\n", + "from langchain_core.messages import HumanMessage, SystemMessage\n", "from langchain_openai import ChatOpenAI\n", - "import json\n", - "import pandas as pd\n", "\n", "# Load environment variables if using .env file\n", "try:\n", @@ -88,11 +86,31 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## LLM-Powered Tool Selection Router\n", + "\n", + "This section demonstrates how to create an intelligent router that uses an LLM to select the most appropriate tool based on user input and tool docstrings.\n", + "\n", + "### Benefits of LLM-Based Tool Selection:\n", + "- **Intelligent Routing**: Understanding of natural language intent\n", + "- **Dynamic Selection**: Can handle complex, multi-step requests \n", + "- **Context Awareness**: Considers conversation history and context\n", + "- **Flexible Matching**: Not limited to keyword patterns\n", + "- **Tool Documentation**: Uses actual tool docstrings for decision making\n" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Simple Tools with Rich Docstrings\n", + "## Simplified Tools with Rich Docstrings\n", "\n", "We've simplified the agent to use only two core tools:\n", "- **search_engine**: For searching through documents, policies, and knowledge base \n", @@ -208,7 +226,7 @@ " Create a timeline, 4) Start with the most critical parts, 5) Review and adjust as needed.\n", " \"\"\"\n", "\n", - "# Collect all tools for the LLM - SIMPLIFIED TO ONLY 2 TOOLS\n", + "# Collect all tools for the LLM router - SIMPLIFIED TO ONLY 2 TOOLS\n", "AVAILABLE_TOOLS = [\n", " search_engine,\n", " task_assistant\n", @@ -233,7 +251,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def create_intelligent_langchain_agent():\n", " \"\"\"Create a simplified LangChain agent with direct tool calling.\"\"\"\n", " \n", @@ -251,7 +268,7 @@ " - Use for: finding company policies, technical documentation, compliance documents\n", " - Examples: \"Find our data privacy policy\", \"Search for API documentation\"\n", "\n", - " 🎯 **task_assistant** - General-purpose task assistance and problem-solving \n", + " **task_assistant** - General-purpose task assistance and problem-solving \n", " - Use for: guidance, recommendations, explaining concepts, planning activities\n", " - Examples: \"How to prepare for an interview\", \"Help plan a meeting\", \"Explain machine learning\"\n", "\n", @@ -278,7 +295,7 @@ " # Get initial response from LLM\n", " response = llm_with_tools.invoke(messages)\n", " messages.append(response)\n", - " \n", + " tools_used = []\n", " # Check if the LLM wants to use tools\n", " if hasattr(response, 'tool_calls') and response.tool_calls:\n", " # Execute tool calls\n", @@ -288,11 +305,13 @@ " for tool in AVAILABLE_TOOLS:\n", " if tool.name == tool_call['name']:\n", " tool_to_call = tool\n", + " tools_used.append(tool_to_call.name)\n", " break\n", " \n", " if tool_to_call:\n", " # Execute the tool\n", " try:\n", + "\n", 
" tool_result = tool_to_call.invoke(tool_call['args'])\n", " # Add tool message to conversation\n", " from langchain_core.messages import ToolMessage\n", @@ -314,7 +333,8 @@ " \"messages\": messages,\n", " \"user_input\": user_input,\n", " \"session_id\": session_id,\n", - " \"context\": {}\n", + " \"context\": {},\n", + " \"tools_used\": tools_used\n", " }\n", " \n", " return invoke_agent\n", @@ -369,7 +389,7 @@ " # Invoke the agent with the user input\n", " result = intelligent_agent(user_input, session_id)\n", " \n", - " return result\n", + " return {\"prediction\": result['messages'][-1].content, \"output\": result, \"tools_used\": result['tools_used']}\n", "\n", "\n", "vm_intelligent_model = vm.init_model(input_id=\"financial_model\", predict_fn=agent_fn)\n", @@ -382,11 +402,25 @@ "metadata": {}, "source": [ "## Prepare Sample Test Dataset\n", - "Now we'll create a test dataset to validate our agent's behavior. This dataset includes:\n", - "- Various user queries that test different agent capabilities\n", - "- Expected tools that should be used for each query\n", - "- Possible valid outputs for each query\n", - "- Unique session IDs to track conversation threads" + "\n", + "We'll create a comprehensive test dataset to evaluate our agent's performance across different scenarios. This dataset includes:\n", + "\n", + "**Diverse Test Cases**: Various types of user requests that test different agent capabilities:\n", + "- **Single Tool Requests**: Simple queries that require one specific tool\n", + "- **Multi-Tool Requests**: Complex queries requiring multiple tools in sequence \n", + "- **Validation Tasks**: Requests for data validation and verification\n", + "- **General Assistance**: Open-ended questions for problem-solving guidance\n", + "\n", + "**Expected Outputs**: For each test case, we define:\n", + "- **Expected Tools**: Which tools should be selected by the router\n", + "- **Possible Outputs**: Valid response patterns or values\n", + "- **Session IDs**: Unique identifiers for conversation tracking\n", + "\n", + "**Test Coverage**: The dataset covers:\n", + "- Document retrieval (search_engine tool)\n", + "- General guidance (task_assistant tool)\n", + "\n", + "This structured approach allows us to systematically evaluate both tool selection accuracy and response quality." 
] }, { @@ -403,13 +437,13 @@ " {\n", " \"input\": \"Find our company's data privacy policy\",\n", " \"expected_tools\": [\"search_engine\"],\n", - " \"possible_outputs\": [\"privacy_policy\", \"data_protection\", \"company_privacy_guidelines\"],\n", + " \"possible_outputs\": [\"privacy_policy.pdf\", \"data_protection.doc\", \"company_privacy_guidelines.txt\"],\n", " \"session_id\": str(uuid.uuid4())\n", " },\n", " {\n", " \"input\": \"Search for loan approval procedures\", \n", " \"expected_tools\": [\"search_engine\"],\n", - " \"possible_outputs\": [\"loan_procedures\", \"approval_process\", \"lending_guidelines\"],\n", + " \"possible_outputs\": [\"loan_procedures.doc\", \"approval_process.pdf\", \"lending_guidelines.txt\"],\n", " \"session_id\": str(uuid.uuid4())\n", " },\n", " {\n", @@ -433,7 +467,7 @@ " {\n", " \"input\": \"Find technical documentation about API endpoints\",\n", " \"expected_tools\": [\"search_engine\"],\n", - " \"possible_outputs\": [\"API_documentation\", \"REST_endpoints\", \"technical\"],\n", + " \"possible_outputs\": [\"API_documentation.pdf\", \"REST_endpoints.doc\", \"technical_guide.txt\"],\n", " \"session_id\": str(uuid.uuid4())\n", " },\n", " {\n", @@ -549,27 +583,6 @@ "vm_test_dataset._df" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Agent prediction column adjustment in dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output = vm_test_dataset._df['financial_model_prediction']\n", - "predictions = [row['messages'][-1].content for row in output]\n", - "\n", - "vm_test_dataset._df['output'] = output\n", - "vm_test_dataset._df['financial_model_prediction'] = predictions\n", - "vm_test_dataset._df.head(2)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -722,21 +735,24 @@ "source": [ "## Tool Call Accuracy Test\n", "\n", - "This test evaluates how accurately our LangChain agent calls the expected tools for different inputs. It's a validation step that measures:\n", + "This test evaluates how accurately our intelligent router selects the correct tools for different user requests. It's a critical validation step that measures:\n", + "\n", + "**Tool Selection Performance**: Analyzes whether the agent correctly identifies and calls the expected tools\n", + "- **Expected vs. 
Actual**: Compares tools that should be called with tools that were actually called\n", + "- **Accuracy Scoring**: Calculates percentage accuracy for tool selection decisions\n", + "- **Multi-tool Handling**: Evaluates performance on requests requiring multiple tools\n", "\n", - "**Tool Call Validation**: Analyzes the actual tool calls made by the agent\n", - "- **Tool Call Extraction**: Extracts tool calls from agent message history\n", - "- **Expected vs Actual Comparison**: Compares expected tools against tools that were actually called\n", - "- **Accuracy Calculation**: Computes accuracy as ratio of matched tools to expected tools\n", - "- **Multi-tool Support**: Handles cases with multiple expected tool calls\n", + "**Router Intelligence Assessment**: Validates the LLM-powered routing system's effectiveness\n", + "- **Intent Recognition**: How well the router understands user intent from natural language\n", + "- **Tool Mapping**: Accuracy of mapping user needs to appropriate tool capabilities\n", + "- **Decision Quality**: Assessment of routing confidence and reasoning\n", "\n", - "**Results Analysis**: For each test case, provides:\n", - "- **Accuracy Score**: Percentage of expected tools that were correctly called\n", - "- **Expected Tools**: List of tools that should have been called\n", - "- **Found Tools**: List of tools that were actually called\n", - "- **Match Details**: Number of matches and total expected tools\n", + "**Failure Analysis**: Identifies patterns in incorrect tool selections to improve the routing logic\n", + "- **Missed Tools**: Cases where expected tools weren't selected\n", + "- **Extra Tools**: Cases where unnecessary tools were selected \n", + "- **Wrong Tools**: Cases where completely incorrect tools were selected\n", "\n", - "The test processes a dataset containing agent outputs and expected tool lists, providing quantitative feedback on the agent's ability to call the right tools for each input." + "This test provides quantitative feedback on the agent's core intelligence - its ability to understand what users need and select the right tools to help them." 
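As a rough illustration of the scoring this cell describes, a minimal sketch of the expected-vs-actual comparison might look like the following. The `expected_tools` and `tools_used` names mirror the columns built earlier in the notebook; the actual ValidMind test implementation is not part of this diff, so treat this as an assumption-laden sketch rather than the real scoring code.

```python
def tool_call_accuracy(expected_tools, tools_used):
    """Fraction of expected tools actually called, plus the mismatches."""
    expected = set(expected_tools)
    found = set(tools_used)
    matched = expected & found
    return {
        "accuracy": len(matched) / len(expected) if expected else 1.0,
        "missed_tools": sorted(expected - found),  # expected but never called
        "extra_tools": sorted(found - expected),   # called but not expected
    }

# Example: a search request that also (unnecessarily) triggered the task assistant
print(tool_call_accuracy(["search_engine"], ["search_engine", "task_assistant"]))
# -> {'accuracy': 1.0, 'missed_tools': [], 'extra_tools': ['task_assistant']}
```

A per-row score like this can then be aggregated across the test dataset to give the quantitative feedback described above, and the `missed_tools` / `extra_tools` fields support the failure analysis.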
] }, { @@ -854,7 +870,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_utils import capture_tool_output_messages, extract_tool_results_only, get_final_agent_response, format_tool_outputs_for_display\n", + "from notebooks.agents.langchain_utils import capture_tool_output_messages\n", "\n", "tool_messages = []\n", "for i, row in vm_test_dataset._df.iterrows():\n", @@ -863,22 +879,10 @@ " result = row['output']\n", " # Capture all tool outputs and metadata\n", " captured_data = capture_tool_output_messages(result)\n", - "\n", - " # Get just the tool results in a simple format\n", - " tool_results = extract_tool_results_only(result)\n", - "\n", - " # Get the final agent response\n", - " final_response = get_final_agent_response(result)\n", - "\n", - " # Print formatted summary\n", - " # print(format_tool_outputs_for_display(captured_data))\n", - "\n", + " \n", " # Access specific tool outputs\n", " for output in captured_data[\"tool_outputs\"]:\n", - " # print(f\"Tool: {output['tool_name']}\")\n", - " # print(f\"Output: {output['content']}\")\n", " tool_message += output['content']\n", - " # print(\"-\" * 30)\n", " tool_messages.append([tool_message])\n", "\n", "vm_test_dataset._df['tool_messages'] = tool_messages" diff --git a/notebooks/agents/langchain_utils.py b/notebooks/agents/langchain_utils.py index c0206ac90..e10954f28 100644 --- a/notebooks/agents/langchain_utils.py +++ b/notebooks/agents/langchain_utils.py @@ -1,20 +1,19 @@ -from typing import Dict, List, Any -from langchain_core.messages import ToolMessage, AIMessage +from typing import Dict, Any +from langchain_core.messages import ToolMessage def capture_tool_output_messages(agent_result: Dict[str, Any]) -> Dict[str, Any]: """ Capture all tool outputs and metadata from agent results. - + Args: agent_result: The result from the LangChain agent execution - Returns: Dictionary containing tool outputs and metadata """ messages = agent_result.get('messages', []) tool_outputs = [] - + for message in messages: if isinstance(message, ToolMessage): tool_outputs.append({ @@ -22,71 +21,9 @@ def capture_tool_output_messages(agent_result: Dict[str, Any]) -> Dict[str, Any] 'content': message.content, 'tool_call_id': getattr(message, 'tool_call_id', None) }) - + return { 'tool_outputs': tool_outputs, 'total_messages': len(messages), 'tool_message_count': len(tool_outputs) } - - -def extract_tool_results_only(agent_result: Dict[str, Any]) -> List[str]: - """ - Extract just the tool results in a simple format. - - Args: - agent_result: The result from the LangChain agent execution - - Returns: - List of tool result strings - """ - messages = agent_result.get('messages', []) - tool_results = [] - - for message in messages: - if isinstance(message, ToolMessage): - tool_results.append(message.content) - - return tool_results - - -def get_final_agent_response(agent_result: Dict[str, Any]) -> str: - """ - Get the final agent response from the conversation. - - Args: - agent_result: The result from the LangChain agent execution - - Returns: - The final response content as a string - """ - messages = agent_result.get('messages', []) - - # Look for the last AI message - for message in reversed(messages): - if isinstance(message, AIMessage): - return message.content - - return "No final response found" - - -def format_tool_outputs_for_display(captured_data: Dict[str, Any]) -> str: - """ - Format tool outputs for readable display. 
- - Args: - captured_data: Data from capture_tool_output_messages - - Returns: - Formatted string for display - """ - output = "Tool Execution Summary:\n" - output += f"Total messages: {captured_data['total_messages']}\n" - output += f"Tool messages: {captured_data['tool_message_count']}\n\n" - - for i, tool_output in enumerate(captured_data['tool_outputs'], 1): - output += f"Tool {i}: {tool_output['tool_name']}\n" - output += f"Output: {tool_output['content']}\n" - output += "-" * 30 + "\n" - - return output diff --git a/notebooks/agents/langgraph_agent_demo.ipynb b/notebooks/agents/langgraph_agent_demo.ipynb index c6df56514..009369840 100644 --- a/notebooks/agents/langgraph_agent_demo.ipynb +++ b/notebooks/agents/langgraph_agent_demo.ipynb @@ -816,7 +816,7 @@ "\n", " result = intelligent_agent.invoke(initial_state, config=session_config)\n", "\n", - " return result\n", + " return {\"prediction\": result['messages'][-1].content, \"output\": result, \"tools_used\": result['selected_tools']}\n", "\n", "\n", "vm_intelligent_model = vm.init_model(input_id=\"financial_model\", predict_fn=agent_fn)\n", @@ -1014,27 +1014,6 @@ "vm_test_dataset._df" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Agent prediction column adjustment in dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output = vm_test_dataset._df['financial_model_prediction']\n", - "predictions = [row['messages'][-1].content for row in output]\n", - "\n", - "vm_test_dataset._df['output'] = output\n", - "vm_test_dataset._df['financial_model_prediction'] = predictions\n", - "vm_test_dataset._df.head(2)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1306,31 +1285,18 @@ "metadata": {}, "outputs": [], "source": [ - "from utils import capture_tool_output_messages, extract_tool_results_only, get_final_agent_response, format_tool_outputs_for_display\n", + "from notebooks.agents.utils import capture_tool_output_messages#, #extract_tool_results_only, get_final_agent_response, format_tool_outputs_for_display\n", "\n", "tool_messages = []\n", "for i, row in vm_test_dataset._df.iterrows():\n", " tool_message = \"\"\n", - " # Print messages in a readable format\n", " result = row['output']\n", " # Capture all tool outputs and metadata\n", " captured_data = capture_tool_output_messages(result)\n", "\n", - " # Get just the tool results in a simple format\n", - " tool_results = extract_tool_results_only(result)\n", - "\n", - " # Get the final agent response\n", - " final_response = get_final_agent_response(result)\n", - "\n", - " # Print formatted summary\n", - " # print(format_tool_outputs_for_display(captured_data))\n", - "\n", " # Access specific tool outputs\n", " for output in captured_data[\"tool_outputs\"]:\n", - " # print(f\"Tool: {output['tool_name']}\")\n", - " # print(f\"Output: {output['content']}\")\n", " tool_message += output['content']\n", - " # print(\"-\" * 30)\n", " tool_messages.append([tool_message])\n", "\n", "vm_test_dataset._df['tool_messages'] = tool_messages" diff --git a/notebooks/agents/langgraph_agent_simple_demo.ipynb b/notebooks/agents/langgraph_agent_simple_demo.ipynb index 2a45621b2..24260c68b 100644 --- a/notebooks/agents/langgraph_agent_simple_demo.ipynb +++ b/notebooks/agents/langgraph_agent_simple_demo.ipynb @@ -388,7 +388,7 @@ "\n", " result = intelligent_agent.invoke(initial_state, config=session_config)\n", "\n", - " return result\n", + " return {\"prediction\": result['messages'][-1].content, 
\"output\": result}\n", "\n", "\n", "vm_intelligent_model = vm.init_model(input_id=\"financial_model\", predict_fn=agent_fn)\n", @@ -396,15 +396,6 @@ "vm_intelligent_model.model = intelligent_agent" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vm_intelligent_model.model" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -562,27 +553,6 @@ "vm_test_dataset._df" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Agent prediction column adjustment in dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output = vm_test_dataset._df['financial_model_prediction']\n", - "predictions = [row['messages'][-1].content for row in output]\n", - "\n", - "vm_test_dataset._df['output'] = output\n", - "vm_test_dataset._df['financial_model_prediction'] = predictions\n", - "vm_test_dataset._df.head(2)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -832,31 +802,18 @@ "metadata": {}, "outputs": [], "source": [ - "from utils import capture_tool_output_messages, extract_tool_results_only, get_final_agent_response, format_tool_outputs_for_display\n", + "from utils import capture_tool_output_messages\n", "\n", "tool_messages = []\n", "for i, row in vm_test_dataset._df.iterrows():\n", " tool_message = \"\"\n", - " # Print messages in a readable format\n", " result = row['output']\n", " # Capture all tool outputs and metadata\n", " captured_data = capture_tool_output_messages(result)\n", - "\n", - " # Get just the tool results in a simple format\n", - " tool_results = extract_tool_results_only(result)\n", - "\n", - " # Get the final agent response\n", - " final_response = get_final_agent_response(result)\n", - "\n", - " # Print formatted summary\n", - " # print(format_tool_outputs_for_display(captured_data))\n", - "\n", + " \n", " # Access specific tool outputs\n", " for output in captured_data[\"tool_outputs\"]:\n", - " # print(f\"Tool: {output['tool_name']}\")\n", - " # print(f\"Output: {output['content']}\")\n", " tool_message += output['content']\n", - " # print(\"-\" * 30)\n", " tool_messages.append([tool_message])\n", "\n", "vm_test_dataset._df['tool_messages'] = tool_messages" diff --git a/notebooks/agents/utils.py b/notebooks/agents/utils.py index 3fc807327..aad0e2f3e 100644 --- a/notebooks/agents/utils.py +++ b/notebooks/agents/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Optional +from typing import Dict, Any from langchain_core.messages import ToolMessage, AIMessage, HumanMessage @@ -102,100 +102,3 @@ def capture_tool_output_messages(result: Dict[str, Any]) -> Dict[str, Any]: } return captured_data - - -def extract_tool_results_only(result: Dict[str, Any]) -> List[Dict[str, str]]: - """ - Extract only the tool results/outputs in a simplified format. - - Args: - result: The result dictionary from a LangGraph agent execution - - Returns: - List of dictionaries with tool name and output content - """ - tool_results = [] - messages = result.get("messages", []) - - for message in messages: - if isinstance(message, ToolMessage): - tool_results.append({ - "tool_name": getattr(message, 'name', 'unknown'), - "output": message.content, - "tool_call_id": getattr(message, 'tool_call_id', None) - }) - - return tool_results - - -def get_final_agent_response(result: Dict[str, Any]) -> Optional[str]: - """ - Get the final response from the agent (last AI message). 
- - Args: - result: The result dictionary from a LangGraph agent execution - - Returns: - The content of the final AI message, or None if not found - """ - messages = result.get("messages", []) - - # Find the last AI message - for message in reversed(messages): - if isinstance(message, AIMessage) and message.content: - return message.content - - return None - - -def format_tool_outputs_for_display(captured_data: Dict[str, Any]) -> str: - """ - Format tool outputs in a readable string format. - - Args: - captured_data: Result from capture_tool_output_messages() - - Returns: - Formatted string representation of tool outputs - """ - output_lines = [] - output_lines.append("🔧 TOOL OUTPUTS SUMMARY") - output_lines.append("=" * 40) - - summary = captured_data["execution_summary"] - output_lines.append(f"Total tools used: {len(summary['tools_used'])}") - output_lines.append(f"Tools: {', '.join(summary['tools_used'])}") - output_lines.append(f"Tool calls: {summary['tool_calls_count']}") - output_lines.append(f"Tool outputs: {summary['tool_outputs_count']}") - output_lines.append("") - - for i, output in enumerate(captured_data["tool_outputs"], 1): - output_lines.append(f"{i}. {output['tool_name'].upper()}") - output_lines.append(f" Output: {output['content'][:100]}{'...' if len(output['content']) > 100 else ''}") - output_lines.append("") - - return "\n".join(output_lines) - - -# Example usage functions -def demo_capture_usage(agent_result): - """Demonstrate how to use the capture functions.""" - - # Capture all tool outputs and metadata - captured = capture_tool_output_messages(agent_result) - - # Get just the tool results - tool_results = extract_tool_results_only(agent_result) - - # Get the final agent response - final_response = get_final_agent_response(agent_result) - - # Format for display - formatted_output = format_tool_outputs_for_display(captured) - - return { - "full_capture": captured, - "tool_results_only": tool_results, - "final_response": final_response, - "formatted_display": formatted_output - } diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e18a90aa4..41bc40fc8 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -303,6 +303,219 @@ def test_assign_predictions_with_no_model_and_prediction_values(self): # Probabilities are not auto-assigned if prediction_values are provided self.assertTrue("logreg_probabilities" not in vm_dataset._df.columns) + def test_assign_predictions_with_classification_predict_fn(self): + """ + Test assigning predictions to dataset with a model created using predict_fn for classification + """ + df = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "y": [0, 1, 0]}) + vm_dataset = DataFrameDataset( + raw_dataset=df, target_column="y", feature_columns=["x1", "x2"] + ) + + # Define a simple classification predict function + def simple_classify_fn(input_dict): + # Simple rule: if x1 + x2 > 5, return 1, else 0 + return 1 if input_dict["x1"] + input_dict["x2"] > 5 else 0 + + vm_model = init_model( + input_id="predict_fn_classifier", predict_fn=simple_classify_fn, __log=False + ) + self.assertIsNone(vm_dataset.prediction_column(vm_model)) + + vm_dataset.assign_predictions(model=vm_model) + self.assertEqual( + vm_dataset.prediction_column(vm_model), "predict_fn_classifier_prediction" + ) + + # Check that the predictions are assigned to the dataset + self.assertTrue("predict_fn_classifier_prediction" in vm_dataset._df.columns) + self.assertIsInstance(vm_dataset.y_pred(vm_model), np.ndarray) + 
self.assertIsInstance(vm_dataset.y_pred_df(vm_model), pd.DataFrame) + + # Verify the actual predictions match our function logic + expected_predictions = [0, 1, 1] # [1+4=5 -> 0, 2+5=7 -> 1, 3+6=9 -> 1] + np.testing.assert_array_equal(vm_dataset.y_pred(vm_model), expected_predictions) + + def test_assign_predictions_with_regression_predict_fn(self): + """ + Test assigning predictions to dataset with a model created using predict_fn for regression + """ + df = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "y": [0.1, 1.2, 2.3]}) + vm_dataset = DataFrameDataset( + raw_dataset=df, target_column="y", feature_columns=["x1", "x2"] + ) + + # Define a simple regression predict function + def simple_regression_fn(input_dict): + # Simple linear combination: x1 * 0.5 + x2 * 0.3 + return input_dict["x1"] * 0.5 + input_dict["x2"] * 0.3 + + vm_model = init_model( + input_id="predict_fn_regressor", predict_fn=simple_regression_fn, __log=False + ) + self.assertIsNone(vm_dataset.prediction_column(vm_model)) + + vm_dataset.assign_predictions(model=vm_model) + self.assertEqual( + vm_dataset.prediction_column(vm_model), "predict_fn_regressor_prediction" + ) + + # Check that the predictions are assigned to the dataset + self.assertTrue("predict_fn_regressor_prediction" in vm_dataset._df.columns) + self.assertIsInstance(vm_dataset.y_pred(vm_model), np.ndarray) + self.assertIsInstance(vm_dataset.y_pred_df(vm_model), pd.DataFrame) + + # Verify the actual predictions match our function logic + expected_predictions = [ + 1 * 0.5 + 4 * 0.3, # 0.5 + 1.2 = 1.7 + 2 * 0.5 + 5 * 0.3, # 1.0 + 1.5 = 2.5 + 3 * 0.5 + 6 * 0.3, # 1.5 + 1.8 = 3.3 + ] + np.testing.assert_array_almost_equal( + vm_dataset.y_pred(vm_model), expected_predictions + ) + + def test_assign_predictions_with_complex_predict_fn(self): + """ + Test assigning predictions to dataset with a predict_fn that returns complex outputs + """ + df = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "y": [0, 1, 0]}) + vm_dataset = DataFrameDataset( + raw_dataset=df, target_column="y", feature_columns=["x1", "x2"] + ) + + # Define a predict function that returns a dictionary + def complex_predict_fn(input_dict): + prediction = 1 if input_dict["x1"] + input_dict["x2"] > 5 else 0 + confidence = abs(input_dict["x1"] - input_dict["x2"]) / 10.0 + return { + "prediction": prediction, + "confidence": confidence, + "feature_sum": input_dict["x1"] + input_dict["x2"], + } + + vm_model = init_model( + input_id="complex_predict_fn", predict_fn=complex_predict_fn, __log=False + ) + + vm_dataset.assign_predictions(model=vm_model) + self.assertEqual( + vm_dataset.prediction_column(vm_model), "complex_predict_fn_prediction" + ) + + # Check that the predictions and other columns are assigned to the dataset + self.assertTrue("complex_predict_fn_prediction" in vm_dataset._df.columns) + self.assertTrue("complex_predict_fn_confidence" in vm_dataset._df.columns) + self.assertTrue("complex_predict_fn_feature_sum" in vm_dataset._df.columns) + + # Verify the prediction values (extracted from "prediction" key in dict) + predictions = vm_dataset.y_pred(vm_model) + expected_predictions = [0, 1, 1] # [1+4=5 -> 0, 2+5=7 -> 1, 3+6=9 -> 1] + np.testing.assert_array_equal(predictions, expected_predictions) + + # Verify other dictionary keys were added as separate columns + confidence_values = vm_dataset._df["complex_predict_fn_confidence"].values + expected_confidence = [0.3, 0.3, 0.3] # |1-4|/10, |2-5|/10, |3-6|/10 + np.testing.assert_array_almost_equal(confidence_values, expected_confidence) + + 
feature_sum_values = vm_dataset._df["complex_predict_fn_feature_sum"].values + expected_feature_sums = [5, 7, 9] # 1+4, 2+5, 3+6 + np.testing.assert_array_equal(feature_sum_values, expected_feature_sums) + + def test_assign_predictions_with_multiple_predict_fn_models(self): + """ + Test assigning predictions from multiple models created with predict_fn + """ + df = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "y": [0, 1, 0]}) + vm_dataset = DataFrameDataset( + raw_dataset=df, target_column="y", feature_columns=["x1", "x2"] + ) + + # Define two different predict functions + def predict_fn_1(input_dict): + return 1 if input_dict["x1"] > 1.5 else 0 + + def predict_fn_2(input_dict): + return 1 if input_dict["x2"] > 4.5 else 0 + + vm_model_1 = init_model( + input_id="predict_fn_model_1", predict_fn=predict_fn_1, __log=False + ) + vm_model_2 = init_model( + input_id="predict_fn_model_2", predict_fn=predict_fn_2, __log=False + ) + + vm_dataset.assign_predictions(model=vm_model_1) + vm_dataset.assign_predictions(model=vm_model_2) + + self.assertEqual( + vm_dataset.prediction_column(vm_model_1), "predict_fn_model_1_prediction" + ) + self.assertEqual( + vm_dataset.prediction_column(vm_model_2), "predict_fn_model_2_prediction" + ) + + # Check that both prediction columns exist + self.assertTrue("predict_fn_model_1_prediction" in vm_dataset._df.columns) + self.assertTrue("predict_fn_model_2_prediction" in vm_dataset._df.columns) + + # Verify predictions are different based on the different logic + predictions_1 = vm_dataset.y_pred(vm_model_1) + predictions_2 = vm_dataset.y_pred(vm_model_2) + + expected_predictions_1 = [0, 1, 1] # x1 > 1.5: [1 -> 0, 2 -> 1, 3 -> 1] + expected_predictions_2 = [0, 1, 1] # x2 > 4.5: [4 -> 0, 5 -> 1, 6 -> 1] + + np.testing.assert_array_equal(predictions_1, expected_predictions_1) + np.testing.assert_array_equal(predictions_2, expected_predictions_2) + + def test_assign_predictions_with_predict_fn_and_prediction_values(self): + """ + Test assigning predictions with predict_fn model but using pre-computed prediction values + """ + df = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "y": [0, 1, 0]}) + vm_dataset = DataFrameDataset( + raw_dataset=df, target_column="y", feature_columns=["x1", "x2"] + ) + + # Define a predict function + def predict_fn(input_dict): + return 1 if input_dict["x1"] + input_dict["x2"] > 5 else 0 + + vm_model = init_model( + input_id="predict_fn_with_values", predict_fn=predict_fn, __log=False + ) + + # Pre-computed predictions (different from what the function would return) + precomputed_predictions = [1, 0, 1] + + with patch.object(vm_model, "predict") as mock_predict: + vm_dataset.assign_predictions( + model=vm_model, prediction_values=precomputed_predictions + ) + # The model's predict method should not be called + mock_predict.assert_not_called() + + self.assertEqual( + vm_dataset.prediction_column(vm_model), "predict_fn_with_values_prediction" + ) + + # Check that the precomputed predictions are used + self.assertTrue("predict_fn_with_values_prediction" in vm_dataset._df.columns) + np.testing.assert_array_equal( + vm_dataset.y_pred(vm_model), precomputed_predictions + ) + + def test_assign_predictions_with_invalid_predict_fn(self): + """ + Test assigning predictions with an invalid predict_fn (should raise error during model creation) + """ + # Try to create a model with a non-callable predict_fn + with self.assertRaises(ValueError) as context: + init_model(input_id="invalid_predict_fn", predict_fn="not_a_function", __log=False) + + 
self.assertIn("FunctionModel requires a callable predict_fn", str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/validmind/models/function.py b/validmind/models/function.py index a8c6067a1..5b3e0f40f 100644 --- a/validmind/models/function.py +++ b/validmind/models/function.py @@ -35,7 +35,8 @@ class FunctionModel(VMModel): Attributes: predict_fn (callable): The predict function that should take a dictionary of - input features and return a prediction. + input features and return a prediction. Can return simple values or + dictionary objects. input_id (str, optional): The input ID for the model. Defaults to None. name (str, optional): The name of the model. Defaults to the name of the predict_fn. prompt (Prompt, optional): If using a prompt, the prompt object that defines the template @@ -55,6 +56,13 @@ def predict(self, X) -> List[Any]: X (pandas.DataFrame): The input features to predict on Returns: - List[Any]: The predictions + List[Any]: The predictions. Can contain simple values or dictionary objects + depending on what the predict_fn returns. """ - return [self.predict_fn(x) for x in X.to_dict(orient="records")] + predictions = [] + for x in X.to_dict(orient="records"): + result = self.predict_fn(x) + # Handle both simple values and complex dictionary returns + predictions.append(result) + + return predictions diff --git a/validmind/vm_models/dataset/dataset.py b/validmind/vm_models/dataset/dataset.py index d40c1d692..fea1566d3 100644 --- a/validmind/vm_models/dataset/dataset.py +++ b/validmind/vm_models/dataset/dataset.py @@ -8,7 +8,7 @@ import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import numpy as np import pandas as pd @@ -258,69 +258,91 @@ def with_options(self, **kwargs: Dict[str, Any]) -> "VMDataset": f"Options {kwargs} are not supported for this input" ) - def assign_predictions( - self, - model: VMModel, - prediction_column: Optional[str] = None, - prediction_values: Optional[List[Any]] = None, - probability_column: Optional[str] = None, - probability_values: Optional[List[float]] = None, - prediction_probabilities: Optional[ - List[float] - ] = None, # DEPRECATED: use probability_values - **kwargs: Dict[str, Any], - ) -> None: - """Assign predictions and probabilities to the dataset. - - Args: - model (VMModel): The model used to generate the predictions. - prediction_column (Optional[str]): The name of the column containing the predictions. - prediction_values (Optional[List[Any]]): The values of the predictions. - probability_column (Optional[str]): The name of the column containing the probabilities. - probability_values (Optional[List[float]]): The values of the probabilities. - prediction_probabilities (Optional[List[float]]): DEPRECATED: The values of the probabilities. - **kwargs: Additional keyword arguments that will get passed through to the model's `predict` method. - """ + def _handle_deprecated_parameters( + self, prediction_probabilities, probability_values + ): + """Handle deprecated parameters and return the correct probability values.""" if prediction_probabilities is not None: warnings.warn( "The `prediction_probabilities` argument is deprecated. 
Use `probability_values` instead.", DeprecationWarning, ) - probability_values = prediction_probabilities - - self._validate_assign_predictions( - model, - prediction_column, - prediction_values, - probability_column, - probability_values, - ) + return prediction_probabilities + return probability_values + def _check_existing_predictions(self, model): + """Check for existing predictions and probabilities, warn if overwriting.""" if self.prediction_column(model): logger.warning("Model predictions already assigned... Overwriting.") if self.probability_column(model): logger.warning("Model probabilities already assigned... Overwriting.") - # if the user passes a column name, we assume it has precomputed predictions + def _get_precomputed_values(self, prediction_column, probability_column): + """Get precomputed prediction and probability values from existing columns.""" + prediction_values = None + probability_values = None + if prediction_column: prediction_values = self._df[prediction_column].values if probability_column: probability_values = self._df[probability_column].values + return prediction_values, probability_values + + def _compute_predictions_if_needed(self, model, prediction_values, **kwargs): + """Compute predictions if not provided.""" if prediction_values is None: X = self.df if isinstance(model, (FunctionModel, PipelineModel)) else self.x - probability_values, prediction_values = compute_predictions( - model, X, **kwargs + return compute_predictions(model, X, **kwargs) + return None, prediction_values + + def _handle_dictionary_predictions(self, model, prediction_values): + """Handle dictionary predictions by converting to separate columns.""" + if ( + prediction_values is not None + and len(prediction_values) > 0 + and isinstance(prediction_values[0], dict) + ): + df_prediction_values = pd.DataFrame.from_dict( + prediction_values, orient="columns" ) - prediction_column = prediction_column or f"{model.input_id}_prediction" + for column_name in df_prediction_values.columns.tolist(): + values = df_prediction_values[column_name].values + + if column_name == "prediction": + prediction_column = f"{model.input_id}_prediction" + self._add_column(prediction_column, values) + self.prediction_column(model, prediction_column) + else: + self._add_column(f"{model.input_id}_{column_name}", values) + + return ( + True, + None, + ) # Return True to indicate dictionary handled, None for prediction_column + return False, None + + def _add_prediction_columns( + self, + model, + prediction_column, + prediction_values, + probability_column, + probability_values, + ): + """Add prediction and probability columns to the dataset.""" + if prediction_column is None: + prediction_column = f"{model.input_id}_prediction" + self._add_column(prediction_column, prediction_values) self.prediction_column(model, prediction_column) if probability_values is not None: - probability_column = probability_column or f"{model.input_id}_probabilities" + if probability_column is None: + probability_column = f"{model.input_id}_probabilities" self._add_column(probability_column, probability_values) self.probability_column(model, probability_column) else: @@ -329,6 +351,91 @@ def assign_predictions( "Not adding probability column to the dataset." 
) + def assign_predictions( + self, + model: VMModel, + prediction_column: Optional[str] = None, + prediction_values: Optional[Any] = None, + probability_column: Optional[str] = None, + probability_values: Optional[Any] = None, + prediction_probabilities: Optional[ + Any + ] = None, # DEPRECATED: use probability_values + **kwargs: Dict[str, Any], + ) -> None: + """Assign predictions and probabilities to the dataset. + + Args: + model (VMModel): The model used to generate the predictions. + prediction_column (Optional[str]): The name of the column containing the predictions. + prediction_values (Optional[Any]): The values of the predictions. Can be array-like (list, numpy array, pandas Series, etc.). + probability_column (Optional[str]): The name of the column containing the probabilities. + probability_values (Optional[Any]): The values of the probabilities. Can be array-like (list, numpy array, pandas Series, etc.). + prediction_probabilities (Optional[Any]): DEPRECATED: The values of the probabilities. Use probability_values instead. + **kwargs: Additional keyword arguments that will get passed through to the model's `predict` method. + """ + # Handle deprecated parameters + probability_values = self._handle_deprecated_parameters( + prediction_probabilities, probability_values + ) + + # Convert pandas Series to numpy array for prediction_values + if ( + hasattr(prediction_values, "values") + and hasattr(prediction_values, "index") + and hasattr(prediction_values, "dtype") + ): + prediction_values = prediction_values.values + + # Convert pandas Series to numpy array for probability_values + if ( + hasattr(probability_values, "values") + and hasattr(probability_values, "index") + and hasattr(probability_values, "dtype") + ): + probability_values = probability_values.values + + # Validate input parameters + self._validate_assign_predictions( + model, + prediction_column, + prediction_values, + probability_column, + probability_values, + ) + + # Check for existing predictions and warn if overwriting + self._check_existing_predictions(model) + + # Get precomputed values if column names are provided + if prediction_column or probability_column: + prediction_values, prob_values_from_column = self._get_precomputed_values( + prediction_column, probability_column + ) + if prob_values_from_column is not None: + probability_values = prob_values_from_column + + # Compute predictions if not provided + if prediction_values is None: + probability_values, prediction_values = self._compute_predictions_if_needed( + model, prediction_values, **kwargs + ) + + # Handle dictionary predictions + is_dict_handled, _ = self._handle_dictionary_predictions( + model, prediction_values + ) + + # Add prediction and probability columns (skip if dictionary was handled) + if not is_dict_handled: + self._add_prediction_columns( + model, + prediction_column, + prediction_values, + probability_column, + probability_values, + ) + def prediction_column(self, model: VMModel, column_name: str = None) -> str: """Get or set the prediction column for a model.""" if column_name and column_name not in self.columns: