Merged
Changes from all commits
22 commits
1b3f67a
support agent use case
AnilSorathiya Jun 24, 2025
723fcab
wrapper function for agent
AnilSorathiya Jun 24, 2025
28d9fbb
ragas metrics
AnilSorathiya Jun 30, 2025
ecf8e09
update ragas metrics
AnilSorathiya Jun 30, 2025
53e8879
fix lint error
AnilSorathiya Jun 30, 2025
1662368
create helper functions
AnilSorathiya Jul 1, 2025
cc84cbc
Merge branch 'main' into anilsorathiya/sc-10863/add-support-for-llm-a…
AnilSorathiya Jul 2, 2025
6f09780
delete old notebook
AnilSorathiya Jul 2, 2025
0bb731e
update description for each section
AnilSorathiya Jul 2, 2025
e758979
simplify agent
AnilSorathiya Jul 9, 2025
7c35cfe
simple demo notebook using langchain agent
AnilSorathiya Jul 10, 2025
9bb70e9
Update description of the simplified langgraph agent demo notebook
AnilSorathiya Jul 10, 2025
894d52a
add brief description to tests
AnilSorathiya Jul 14, 2025
d86a9af
add brief description to tests
AnilSorathiya Jul 14, 2025
884000f
Allow dict return type predict_fn
AnilSorathiya Jul 17, 2025
fbd5aa9
update notebook and refactor utils
AnilSorathiya Jul 18, 2025
daceabf
lint fix
AnilSorathiya Jul 18, 2025
5f8823a
Merge branch 'main' into anilsorathiya/sc-11324/extend-the-predict-fn…
AnilSorathiya Jul 18, 2025
70a5636
fix the test failure
AnilSorathiya Jul 18, 2025
33b06fb
new unit tests for multiple columns return in assign_predictions
AnilSorathiya Jul 18, 2025
8e12bd2
update notebooks to return multiple values in predict_fn
AnilSorathiya Jul 18, 2025
cd29fca
append input_id in column names
AnilSorathiya Jul 23, 2025
136 changes: 70 additions & 66 deletions notebooks/agents/langchain_agent_simple_demo.ipynb
@@ -57,12 +57,10 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Dict, Any\n",
"from typing import Optional, Dict, Any\n",
"from langchain.tools import tool\n",
"from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from langchain_openai import ChatOpenAI\n",
"import json\n",
"import pandas as pd\n",
"\n",
"# Load environment variables if using .env file\n",
"try:\n",
@@ -88,11 +86,31 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "raw"
}
},
"source": [
"## LLM-Powered Tool Selection Router\n",
"\n",
"This section demonstrates how to create an intelligent router that uses an LLM to select the most appropriate tool based on user input and tool docstrings.\n",
"\n",
"### Benefits of LLM-Based Tool Selection:\n",
"- **Intelligent Routing**: Understanding of natural language intent\n",
"- **Dynamic Selection**: Can handle complex, multi-step requests \n",
"- **Context Awareness**: Considers conversation history and context\n",
"- **Flexible Matching**: Not limited to keyword patterns\n",
"- **Tool Documentation**: Uses actual tool docstrings for decision making\n"
]
},
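A minimal sketch of the routing mechanism this cell describes, assuming `langchain-openai` is installed; the model name is a placeholder and the notebook's own configuration may differ:

```python
# Sketch: docstring-driven tool routing. bind_tools() forwards each tool's
# name and docstring-derived description to the model, so tool choice is
# based on documentation rather than keyword matching.
from langchain.tools import tool
from langchain_openai import ChatOpenAI

@tool
def search_engine(query: str) -> str:
    """Search documents, policies, and the knowledge base."""
    return f"Results for: {query}"

@tool
def task_assistant(request: str) -> str:
    """Provide general-purpose guidance, planning, and explanations."""
    return f"Guidance for: {request}"

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # model name is an assumption
router = llm.bind_tools([search_engine, task_assistant])

response = router.invoke("Find our data privacy policy")
print([call["name"] for call in response.tool_calls])  # expected: ['search_engine']
```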
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple Tools with Rich Docstrings\n",
"## Simplified Tools with Rich Docstrings\n",
"\n",
"We've simplified the agent to use only two core tools:\n",
"- **search_engine**: For searching through documents, policies, and knowledge base \n",
@@ -208,7 +226,7 @@
" Create a timeline, 4) Start with the most critical parts, 5) Review and adjust as needed.\n",
" \"\"\"\n",
"\n",
"# Collect all tools for the LLM - SIMPLIFIED TO ONLY 2 TOOLS\n",
"# Collect all tools for the LLM router - SIMPLIFIED TO ONLY 2 TOOLS\n",
"AVAILABLE_TOOLS = [\n",
" search_engine,\n",
" task_assistant\n",
@@ -233,7 +251,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"def create_intelligent_langchain_agent():\n",
" \"\"\"Create a simplified LangChain agent with direct tool calling.\"\"\"\n",
" \n",
@@ -251,7 +268,7 @@
" - Use for: finding company policies, technical documentation, compliance documents\n",
" - Examples: \"Find our data privacy policy\", \"Search for API documentation\"\n",
"\n",
" 🎯 **task_assistant** - General-purpose task assistance and problem-solving \n",
" **task_assistant** - General-purpose task assistance and problem-solving \n",
" - Use for: guidance, recommendations, explaining concepts, planning activities\n",
" - Examples: \"How to prepare for an interview\", \"Help plan a meeting\", \"Explain machine learning\"\n",
"\n",
@@ -278,7 +295,7 @@
" # Get initial response from LLM\n",
" response = llm_with_tools.invoke(messages)\n",
" messages.append(response)\n",
" \n",
" tools_used = []\n",
" # Check if the LLM wants to use tools\n",
" if hasattr(response, 'tool_calls') and response.tool_calls:\n",
" # Execute tool calls\n",
@@ -288,11 +305,13 @@
" for tool in AVAILABLE_TOOLS:\n",
" if tool.name == tool_call['name']:\n",
" tool_to_call = tool\n",
" tools_used.append(tool_to_call.name)\n",
" break\n",
" \n",
" if tool_to_call:\n",
" # Execute the tool\n",
" try:\n",
"\n",
" tool_result = tool_to_call.invoke(tool_call['args'])\n",
" # Add tool message to conversation\n",
" from langchain_core.messages import ToolMessage\n",
@@ -314,7 +333,8 @@
" \"messages\": messages,\n",
" \"user_input\": user_input,\n",
" \"session_id\": session_id,\n",
" \"context\": {}\n",
" \"context\": {},\n",
" \"tools_used\": tools_used\n",
" }\n",
" \n",
" return invoke_agent\n",
@@ -369,7 +389,7 @@
" # Invoke the agent with the user input\n",
" result = intelligent_agent(user_input, session_id)\n",
" \n",
" return result\n",
" return {\"prediction\": result['messages'][-1].content, \"output\": result, \"tools_used\": result['tools_used']}\n",
"\n",
"\n",
"vm_intelligent_model = vm.init_model(input_id=\"financial_model\", predict_fn=agent_fn)\n",
@@ -382,11 +402,25 @@
"metadata": {},
"source": [
"## Prepare Sample Test Dataset\n",
"Now we'll create a test dataset to validate our agent's behavior. This dataset includes:\n",
"- Various user queries that test different agent capabilities\n",
"- Expected tools that should be used for each query\n",
"- Possible valid outputs for each query\n",
"- Unique session IDs to track conversation threads"
"\n",
"We'll create a comprehensive test dataset to evaluate our agent's performance across different scenarios. This dataset includes:\n",
"\n",
"**Diverse Test Cases**: Various types of user requests that test different agent capabilities:\n",
"- **Single Tool Requests**: Simple queries that require one specific tool\n",
"- **Multi-Tool Requests**: Complex queries requiring multiple tools in sequence \n",
"- **Validation Tasks**: Requests for data validation and verification\n",
"- **General Assistance**: Open-ended questions for problem-solving guidance\n",
"\n",
"**Expected Outputs**: For each test case, we define:\n",
"- **Expected Tools**: Which tools should be selected by the router\n",
"- **Possible Outputs**: Valid response patterns or values\n",
"- **Session IDs**: Unique identifiers for conversation tracking\n",
"\n",
"**Test Coverage**: The dataset covers:\n",
"- Document retrieval (search_engine tool)\n",
"- General guidance (task_assistant tool)\n",
"\n",
"This structured approach allows us to systematically evaluate both tool selection accuracy and response quality."
]
},
{
@@ -403,13 +437,13 @@
" {\n",
" \"input\": \"Find our company's data privacy policy\",\n",
" \"expected_tools\": [\"search_engine\"],\n",
" \"possible_outputs\": [\"privacy_policy\", \"data_protection\", \"company_privacy_guidelines\"],\n",
" \"possible_outputs\": [\"privacy_policy.pdf\", \"data_protection.doc\", \"company_privacy_guidelines.txt\"],\n",
" \"session_id\": str(uuid.uuid4())\n",
" },\n",
" {\n",
" \"input\": \"Search for loan approval procedures\", \n",
" \"expected_tools\": [\"search_engine\"],\n",
" \"possible_outputs\": [\"loan_procedures\", \"approval_process\", \"lending_guidelines\"],\n",
" \"possible_outputs\": [\"loan_procedures.doc\", \"approval_process.pdf\", \"lending_guidelines.txt\"],\n",
" \"session_id\": str(uuid.uuid4())\n",
" },\n",
" {\n",
@@ -433,7 +467,7 @@
" {\n",
" \"input\": \"Find technical documentation about API endpoints\",\n",
" \"expected_tools\": [\"search_engine\"],\n",
" \"possible_outputs\": [\"API_documentation\", \"REST_endpoints\", \"technical\"],\n",
" \"possible_outputs\": [\"API_documentation.pdf\", \"REST_endpoints.doc\", \"technical_guide.txt\"],\n",
" \"session_id\": str(uuid.uuid4())\n",
" },\n",
" {\n",
@@ -549,27 +583,6 @@
"vm_test_dataset._df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Agent prediction column adjustment in dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = vm_test_dataset._df['financial_model_prediction']\n",
"predictions = [row['messages'][-1].content for row in output]\n",
"\n",
"vm_test_dataset._df['output'] = output\n",
"vm_test_dataset._df['financial_model_prediction'] = predictions\n",
"vm_test_dataset._df.head(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -722,21 +735,24 @@
"source": [
"## Tool Call Accuracy Test\n",
"\n",
"This test evaluates how accurately our LangChain agent calls the expected tools for different inputs. It's a validation step that measures:\n",
"This test evaluates how accurately our intelligent router selects the correct tools for different user requests. It's a critical validation step that measures:\n",
"\n",
"**Tool Selection Performance**: Analyzes whether the agent correctly identifies and calls the expected tools\n",
"- **Expected vs. Actual**: Compares tools that should be called with tools that were actually called\n",
"- **Accuracy Scoring**: Calculates percentage accuracy for tool selection decisions\n",
"- **Multi-tool Handling**: Evaluates performance on requests requiring multiple tools\n",
"\n",
"**Tool Call Validation**: Analyzes the actual tool calls made by the agent\n",
"- **Tool Call Extraction**: Extracts tool calls from agent message history\n",
"- **Expected vs Actual Comparison**: Compares expected tools against tools that were actually called\n",
"- **Accuracy Calculation**: Computes accuracy as ratio of matched tools to expected tools\n",
"- **Multi-tool Support**: Handles cases with multiple expected tool calls\n",
"**Router Intelligence Assessment**: Validates the LLM-powered routing system's effectiveness\n",
"- **Intent Recognition**: How well the router understands user intent from natural language\n",
"- **Tool Mapping**: Accuracy of mapping user needs to appropriate tool capabilities\n",
"- **Decision Quality**: Assessment of routing confidence and reasoning\n",
"\n",
"**Results Analysis**: For each test case, provides:\n",
"- **Accuracy Score**: Percentage of expected tools that were correctly called\n",
"- **Expected Tools**: List of tools that should have been called\n",
"- **Found Tools**: List of tools that were actually called\n",
"- **Match Details**: Number of matches and total expected tools\n",
"**Failure Analysis**: Identifies patterns in incorrect tool selections to improve the routing logic\n",
"- **Missed Tools**: Cases where expected tools weren't selected\n",
"- **Extra Tools**: Cases where unnecessary tools were selected \n",
"- **Wrong Tools**: Cases where completely incorrect tools were selected\n",
"\n",
"The test processes a dataset containing agent outputs and expected tool lists, providing quantitative feedback on the agent's ability to call the right tools for each input."
"This test provides quantitative feedback on the agent's core intelligence - its ability to understand what users need and select the right tools to help them."
]
},
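A minimal sketch of the accuracy computation described above; the function name and the set-overlap scoring are illustrative assumptions, not the test's actual implementation:

```python
# Illustrative scoring: fraction of expected tools that the router actually
# called. Not the test's real implementation, just the idea it describes.
from typing import List

def tool_call_accuracy(expected: List[str], used: List[str]) -> float:
    """Return the share of expected tools found among the tools called."""
    if not expected:
        return 1.0  # nothing required, so nothing could be missed
    matched = sum(1 for t in expected if t in used)
    return matched / len(expected)

assert tool_call_accuracy(["search_engine"], ["search_engine"]) == 1.0
assert tool_call_accuracy(["search_engine", "task_assistant"], ["search_engine"]) == 0.5
```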
{
@@ -854,7 +870,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_utils import capture_tool_output_messages, extract_tool_results_only, get_final_agent_response, format_tool_outputs_for_display\n",
"from notebooks.agents.langchain_utils import capture_tool_output_messages\n",
"\n",
"tool_messages = []\n",
"for i, row in vm_test_dataset._df.iterrows():\n",
@@ -863,22 +879,10 @@
" result = row['output']\n",
" # Capture all tool outputs and metadata\n",
" captured_data = capture_tool_output_messages(result)\n",
"\n",
" # Get just the tool results in a simple format\n",
" tool_results = extract_tool_results_only(result)\n",
"\n",
" # Get the final agent response\n",
" final_response = get_final_agent_response(result)\n",
"\n",
" # Print formatted summary\n",
" # print(format_tool_outputs_for_display(captured_data))\n",
"\n",
" \n",
" # Access specific tool outputs\n",
" for output in captured_data[\"tool_outputs\"]:\n",
" # print(f\"Tool: {output['tool_name']}\")\n",
" # print(f\"Output: {output['content']}\")\n",
" tool_message += output['content']\n",
" # print(\"-\" * 30)\n",
" tool_messages.append([tool_message])\n",
"\n",
"vm_test_dataset._df['tool_messages'] = tool_messages"
73 changes: 5 additions & 68 deletions notebooks/agents/langchain_utils.py
@@ -1,92 +1,29 @@
from typing import Dict, List, Any
from langchain_core.messages import ToolMessage, AIMessage
from typing import Dict, Any
from langchain_core.messages import ToolMessage


def capture_tool_output_messages(agent_result: Dict[str, Any]) -> Dict[str, Any]:
"""
Capture all tool outputs and metadata from agent results.

Args:
agent_result: The result from the LangChain agent execution

Returns:
Dictionary containing tool outputs and metadata
"""
messages = agent_result.get('messages', [])
tool_outputs = []

for message in messages:
if isinstance(message, ToolMessage):
tool_outputs.append({
'tool_name': 'unknown', # ToolMessage doesn't directly contain tool name
'content': message.content,
'tool_call_id': getattr(message, 'tool_call_id', None)
})

return {
'tool_outputs': tool_outputs,
'total_messages': len(messages),
'tool_message_count': len(tool_outputs)
}


def extract_tool_results_only(agent_result: Dict[str, Any]) -> List[str]:
"""
Extract just the tool results in a simple format.

Args:
agent_result: The result from the LangChain agent execution

Returns:
List of tool result strings
"""
messages = agent_result.get('messages', [])
tool_results = []

for message in messages:
if isinstance(message, ToolMessage):
tool_results.append(message.content)

return tool_results


def get_final_agent_response(agent_result: Dict[str, Any]) -> str:
"""
Get the final agent response from the conversation.

Args:
agent_result: The result from the LangChain agent execution

Returns:
The final response content as a string
"""
messages = agent_result.get('messages', [])

# Look for the last AI message
for message in reversed(messages):
if isinstance(message, AIMessage):
return message.content

return "No final response found"


def format_tool_outputs_for_display(captured_data: Dict[str, Any]) -> str:
"""
Format tool outputs for readable display.

Args:
captured_data: Data from capture_tool_output_messages

Returns:
Formatted string for display
"""
output = "Tool Execution Summary:\n"
output += f"Total messages: {captured_data['total_messages']}\n"
output += f"Tool messages: {captured_data['tool_message_count']}\n\n"

for i, tool_output in enumerate(captured_data['tool_outputs'], 1):
output += f"Tool {i}: {tool_output['tool_name']}\n"
output += f"Output: {tool_output['content']}\n"
output += "-" * 30 + "\n"

return output
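A quick usage sketch for the retained `capture_tool_output_messages` helper, run against a hand-built result of the shape `invoke_agent()` returns; the message contents are made up for illustration:

```python
# Hand-built agent result mirroring the {"messages": [...]} shape returned by
# invoke_agent(); contents are fabricated purely for illustration.
from langchain_core.messages import AIMessage, ToolMessage

agent_result = {
    "messages": [
        AIMessage(content="", tool_calls=[
            {"name": "search_engine", "args": {"query": "privacy policy"}, "id": "call_1"}
        ]),
        ToolMessage(content="privacy_policy.pdf", tool_call_id="call_1"),
        AIMessage(content="I found privacy_policy.pdf."),
    ]
}

captured = capture_tool_output_messages(agent_result)
print(captured["tool_message_count"])          # 1
print(captured["tool_outputs"][0]["content"])  # privacy_policy.pdf
```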