diff --git a/examples/customer_support_agent/.env.example b/examples/customer_support_agent/.env.example new file mode 100644 index 0000000..ec570b3 --- /dev/null +++ b/examples/customer_support_agent/.env.example @@ -0,0 +1,11 @@ +HIGHFLAME_API_KEY= +LLM_API_KEY= + +HIGHFLAME_ROUTE= +MODEL= + +LOG_LEVEL=DEBUG +MCP_SERVER_HOST=0.0.0.0 +MCP_SERVER_PORT=9000 +MCP_SERVER_URL=http://0.0.0.0:9000/mcp +DATABASE_PATH=./src/db/support_agent.db \ No newline at end of file diff --git a/examples/customer_support_agent/.gitignore b/examples/customer_support_agent/.gitignore new file mode 100644 index 0000000..5d2b787 --- /dev/null +++ b/examples/customer_support_agent/.gitignore @@ -0,0 +1,65 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ +env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Environment variables +.env +.env.local + +# Database +*.db +*.sqlite +*.sqlite3 + +# Logs +*.log +logs/ +!logs/.gitkeep +/tmp/ + +# OS +.DS_Store +Thumbs.db + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# MCP Server logs +/tmp/mcp_server.log + +# Temporary files +*.tmp +*.bak diff --git a/examples/customer_support_agent/README.md b/examples/customer_support_agent/README.md new file mode 100644 index 0000000..3eda3c3 --- /dev/null +++ b/examples/customer_support_agent/README.md @@ -0,0 +1,614 @@ +# Customer Support Agent with LangGraph and MCP + +An AI-powered customer support agent built with LangGraph and Model Context Protocol (MCP) that handles customer service scenarios including order inquiries, technical support, billing issues, and general questions. The agent uses Highflame as a unified LLM provider supporting both OpenAI and Google Gemini models. 
+ +## Features + +- **Highflame LLM Integration**: Unified provider for OpenAI and Google Gemini models +- **MCP Architecture**: Database operations isolated in MCP server for scalability and separation of concerns +- **Intelligent Conversation Flow**: LangGraph manages complex conversation states and routing +- **Conversation Memory**: Maintains context across multiple turns using LangGraph's MemorySaver +- **11 Database Tools**: Via MCP server (orders, customers, tickets, knowledge base) +- **Direct Tools**: Web search and email capabilities +- **REST API**: FastAPI server with `/chat` and `/generate` endpoints +- **Streamlit UI**: Interactive chat interface with conversation history and debug views +- **Comprehensive Logging**: Timestamped logs saved to `logs/` directory + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Agent Application │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ LangGraph Agent (State Machine) │ │ +│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ +│ │ │ Start │→ │Understand │→ │ Call │ │ │ +│ │ │ Node │ │ Intent │ │ Tools │ │ │ +│ │ └──────────┘ └──────────┘ └──────────┘ │ │ +│ │ ↓ │ │ +│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ +│ │ │ Synthesize│← │ Tools │← │ Tool │ │ │ +│ │ │ Response │ │ Node │ │ Calls │ │ │ +│ │ └──────────┘ └──────────┘ └──────────┘ │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Highflame LLM (Unified Provider) │ │ +│ │ • OpenAI Route (gpt-4o-mini, etc.) │ │ +│ │ • Google Route (gemini-2.5-flash-lite, etc.) 
│ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Tool Router │ │ +│ │ ├── MCP Client → MCP Server (11 DB tools) │ │ +│ │ └── Direct Tools (web search, email) │ │ +│ └──────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ MCP Server (Port 9000) │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Knowledge Base Tools (2) │ │ +│ │ • search_knowledge_base_tool │ │ +│ │ • get_knowledge_base_by_category_tool │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Order Tools (3) │ │ +│ │ • lookup_order_tool │ │ +│ │ • get_order_status_tool │ │ +│ │ • get_order_history_tool │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Customer Tools (3) │ │ +│ │ • lookup_customer_tool │ │ +│ │ • get_customer_profile_tool │ │ +│ │ • create_customer_tool │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Ticket Tools (3) │ │ +│ │ • create_ticket_tool │ │ +│ │ • update_ticket_tool │ │ +│ │ • get_ticket_tool │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ ↓ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ SQLite Database │ │ +│ │ • customers, orders, tickets, knowledge_base │ │ +│ └──────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Quick Start + +### Prerequisites + +- Python 3.11 or higher +- pip package manager +- Highflame API key +- OpenAI API key (for `openai` route) or Google Gemini API key (for `google` route) + +### Step 1: Install Dependencies + +```bash +# Clone or navigate to 
the project directory +cd agent + +# Install all required packages +pip install -r requirements.txt +``` + +### Step 2: Configure Environment Variables + +Create a `.env` file in the project root directory: + +```env +# Highflame Configuration (Required) +HIGHFLAME_API_KEY=your_highflame_api_key_here +HIGHFLAME_ROUTE=google # or 'openai' +MODEL=gemini-2.5-flash-lite # or 'gpt-4o-mini' for OpenAI route +LLM_API_KEY=your_openai_or_gemini_api_key_here + +# MCP Server Configuration +MCP_SERVER_URL=http://localhost:9000/mcp +MCP_SERVER_HOST=0.0.0.0 +MCP_SERVER_PORT=9000 + +# Database Configuration +DATABASE_PATH=./src/db/support_agent.db + +# API Server Configuration (Optional) +PORT=8000 + +# Email Configuration (Optional - for email tool) +SMTP_SERVER=smtp.gmail.com +SMTP_PORT=587 +SMTP_USERNAME=your_email@gmail.com +SMTP_PASSWORD=your_app_password +FROM_EMAIL=your_email@gmail.com + +# Logging Configuration (Optional) +LOG_LEVEL=INFO # DEBUG, INFO, WARNING, ERROR +``` + +**Important Notes:** +- `HIGHFLAME_ROUTE` must be either `openai` or `google` +- For `openai` route: `LLM_API_KEY` should be your OpenAI API key, `MODEL` should be an OpenAI model (e.g., `gpt-4o-mini`) +- For `google` route: `LLM_API_KEY` should be your Google Gemini API key, `MODEL` should be a Gemini model (e.g., `gemini-2.5-flash-lite`) + +### Step 3: Start All Services + +**Option A: Using the Startup Script (Recommended)** + +```bash +# Make the script executable +chmod +x start.sh + +# Start all services +./start.sh +``` + +This will start: +1. MCP Server on port 9000 +2. FastAPI Server on port 8000 +3. 
Streamlit UI on port 8501 + +**Option B: Manual Startup** + +Terminal 1 - MCP Server: +```bash +python -m src.mcp_server.server +``` + +Terminal 2 - FastAPI Server: +```bash +uvicorn src.api:app --reload --host 0.0.0.0 --port 8000 +``` + +Terminal 3 - Streamlit UI: +```bash +streamlit run src/ui/app.py +``` + +### Step 4: Access the Services + +- **Streamlit UI**: http://localhost:8501 +- **FastAPI Server**: http://localhost:8000 +- **API Documentation**: http://localhost:8000/docs +- **MCP Server**: http://localhost:9000/mcp + +## Project Structure + +``` +agent/ +├── src/ +│ ├── agent/ +│ │ ├── database/ # Database models and queries +│ │ │ ├── models.py # SQLAlchemy ORM models +│ │ │ ├── queries.py # Database query functions +│ │ │ └── setup.py # Database initialization +│ │ ├── tools/ # Direct tools (no DB access) +│ │ │ ├── web_search.py # DuckDuckGo search +│ │ │ └── email.py # SMTP email sending +│ │ ├── graph.py # LangGraph agent definition +│ │ ├── llm.py # Highflame LLM integration +│ │ ├── mcp_tools.py # MCP client wrapper +│ │ └── state.py # Agent state schema +│ ├── mcp_server/ +│ │ └── server.py # MCP server with 11 DB tools +│ ├── ui/ +│ │ ├── app.py # Streamlit chat UI +│ │ └── db_viewer.py # Database viewer +│ ├── utils/ +│ │ └── logger.py # Centralized logging +│ ├── api.py # FastAPI server +│ └── main.py # CLI interface +├── tests/ +│ ├── test_agent.py # Agent tests +│ └── test_mcp_server.py # MCP server tests +├── docs/ +│ └── API.md # API documentation +├── logs/ # Application logs (auto-generated) +├── requirements.txt +├── start.sh # Startup script +└── README.md +``` + +## Available Tools + +### MCP Server Tools (Database Operations) + +All database operations are handled through the MCP server for better scalability and separation of concerns. 
+ +**Knowledge Base (2 tools)** +- `search_knowledge_base_tool` - Search help articles by query string +- `get_knowledge_base_by_category_tool` - Get articles filtered by category + +**Orders (3 tools)** +- `lookup_order_tool` - Get detailed order information by order number +- `get_order_status_tool` - Check current status of an order +- `get_order_history_tool` - Get complete order history for a customer + +**Customers (3 tools)** +- `lookup_customer_tool` - Find customer by email, phone, or ID +- `get_customer_profile_tool` - Get full customer profile with all details +- `create_customer_tool` - Create a new customer record + +**Tickets (3 tools)** +- `create_ticket_tool` - Create a new support ticket +- `update_ticket_tool` - Update ticket status, priority, or assignee +- `get_ticket_tool` - Get ticket details by ID + +### Direct Tools (No Database) + +These tools run directly in the agent without going through the MCP server. + +- `web_search_tool` - Search the web using DuckDuckGo +- `web_search_news_tool` - Search for news articles +- `send_email_tool` - Send emails via SMTP (requires SMTP configuration) + +## Configuration Details + +### Highflame LLM Configuration + +The agent uses Highflame as a unified provider. 
Configure it using these environment variables: + +**For OpenAI Route:** +```env +HIGHFLAME_API_KEY=your_highflame_key +HIGHFLAME_ROUTE=openai +MODEL=gpt-4o-mini +LLM_API_KEY=sk-your_openai_key_here +``` + +**For Google Route:** +```env +HIGHFLAME_API_KEY=your_highflame_key +HIGHFLAME_ROUTE=google +MODEL=gemini-2.5-flash-lite +LLM_API_KEY=your_gemini_api_key_here +``` + +### MCP Server Configuration + +The MCP server can run locally or remotely: + +**Local Deployment:** +```env +MCP_SERVER_URL=http://localhost:9000/mcp +MCP_SERVER_HOST=0.0.0.0 +MCP_SERVER_PORT=9000 +``` + +**Remote Deployment:** +```env +MCP_SERVER_URL=http://your-server-ip:9000/mcp +``` + +### Database Configuration + +The SQLite database is automatically initialized on first run: + +```env +DATABASE_PATH=./src/db/support_agent.db +``` + +The database includes mock data for testing. To reset the database, delete the `.db` file and restart the services. + +## Usage Examples + +### Using the Streamlit UI + +1. Start all services using `./start.sh` or manually +2. Open http://localhost:8501 in your browser +3. Start a new conversation or continue an existing one +4. Ask questions like: + - "What is the status of order ORD-001?" + - "Tell me about customer ID 10" + - "Create a support ticket for a damaged item" + - "Search the web for current weather in New York" + +### Using the REST API + +**Stateful Chat (maintains conversation context):** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "What is the status of order ORD-001?", + "thread_id": "user-123", + "customer_id": 1 + }' +``` + +**Stateless Query (no context):** +```bash +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "message": "How do I reset my password?" + }' +``` + +See `docs/API.md` for complete API documentation. 
+ +### Using the CLI + +**Interactive Mode:** +```bash +python src/main.py +``` + +**Test Suite:** +```bash +python src/main.py test +``` + +**Single Query:** +```bash +python src/main.py "What is the status of order ORD-001?" +``` + +## Testing + +### Run All Tests + +```bash +pytest +``` + +### Test MCP Server + +```bash +# Start server first (in one terminal) +python -m src.mcp_server.server + +# Run tests (in another terminal) +pytest tests/test_mcp_server.py -v +``` + +### Test Agent + +```bash +pytest tests/test_agent.py -v +``` + +### Manual Testing + +```bash +# Run comprehensive test suite +python src/main.py test +``` + +## API Endpoints + +### FastAPI Server Endpoints + +**GET /health** - Health check +```bash +curl http://localhost:8000/health +``` + +**GET /** - API information +```bash +curl http://localhost:8000/ +``` + +**POST /chat** - Stateful conversation +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{"message": "Hello", "thread_id": "test-123"}' +``` + +**POST /generate** - Stateless query +```bash +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{"message": "What can you help me with?"}' +``` + +See `docs/API.md` for complete API documentation with request/response examples. + +## Database Schema + +The SQLite database includes the following tables: + +- **customers** - Customer information (name, email, phone, address) +- **orders** - Order details (order number, status, items, dates) +- **tickets** - Support tickets (status, priority, description, assignments) +- **knowledge_base** - Help articles and FAQs (title, content, category) +- **conversations** - Chat history (messages, customer associations) + +The database is automatically initialized with mock data on first run. 
You can view the database using: + +```bash +streamlit run src/ui/db_viewer.py +``` + +## Logging + +All application logs are saved to the `logs/` directory with timestamped filenames: + +``` +logs/ +├── app_20251230_093000.log +├── app_20251230_094500.log +└── ... +``` + +Log levels can be configured via `LOG_LEVEL` environment variable: +- `DEBUG` - Detailed debugging information +- `INFO` - General informational messages +- `WARNING` - Warning messages +- `ERROR` - Error messages only + +## Development + +### Adding New Tools + +**For database tools** - Add to `src/mcp_server/server.py`: + +```python +@mcp.tool() +def my_new_tool(arg1: str, arg2: int) -> str: + """Tool description for the LLM.""" + db = get_session() + try: + # Your database logic here + result = perform_query(db, arg1, arg2) + return result + finally: + db.close() +``` + +Then add wrapper in `src/agent/mcp_tools.py`: + +```python +class MyNewToolInput(BaseModel): + arg1: str = Field(description="First argument") + arg2: int = Field(description="Second argument") + +my_new_tool = StructuredTool( + name="my_new_tool", + description="Tool description", + args_schema=MyNewToolInput, + func=lambda arg1, arg2: call_mcp_tool("my_new_tool", arg1=arg1, arg2=arg2), +) +``` + +**For direct tools** - Add to `src/agent/tools/`: + +```python +from langchain_core.tools import tool + +@tool +def my_direct_tool(query: str) -> str: + """Tool description for the LLM.""" + # Your logic here + return result +``` + +Then import and add to `get_tools()` in `src/agent/graph.py`. 

### Code Quality

```bash
# Format code
black src/ tests/

# Lint
flake8 src/ tests/

# Type check
mypy src/
```

## Troubleshooting

### MCP Server Not Starting

**Check if port 9000 is available:**
```bash
lsof -i :9000
```

**Kill process if needed:**
```bash
kill -9 <PID>
```

**Check MCP server logs:**
```bash
tail -f logs/app_*.log
```

### API Server Not Starting

**Check if port 8000 is available:**
```bash
lsof -i :8000
```

**Verify environment variables:**
```bash
python -c "import os; from dotenv import load_dotenv; load_dotenv(); print('HIGHFLAME_API_KEY:', bool(os.getenv('HIGHFLAME_API_KEY')))"
```

### Connection Refused to MCP Server

1. Verify MCP server is running: `curl http://localhost:9000/mcp`
2. Check `MCP_SERVER_URL` in `.env` matches server address
3. Review MCP server logs in `logs/` directory

### Tools Not Working

1. Verify MCP server is running and accessible
2. Check `MCP_SERVER_URL` in `.env` is correct
3. Review agent logs for tool call errors
4. 
Test MCP server directly: `pytest tests/test_mcp_server.py` + +### API Key Errors + +Ensure all required API keys are set in `.env`: + +```bash +# Check if keys are loaded +python -c " +import os +from dotenv import load_dotenv +load_dotenv() +print('HIGHFLAME_API_KEY:', bool(os.getenv('HIGHFLAME_API_KEY'))) +print('HIGHFLAME_ROUTE:', os.getenv('HIGHFLAME_ROUTE')) +print('MODEL:', os.getenv('MODEL')) +print('LLM_API_KEY:', bool(os.getenv('LLM_API_KEY'))) +" +``` + +### Database Errors + +**Reset database:** +```bash +rm src/db/support_agent.db +# Restart services - database will be recreated with mock data +``` + +**View database:** +```bash +streamlit run src/ui/db_viewer.py +``` + +## Security Considerations + +- API keys stored in `.env` (not committed to git) +- Database operations isolated in MCP server +- Input validation on all tool parameters +- CORS enabled for API endpoints (configure as needed) +- Logs may contain sensitive information - secure the `logs/` directory + +## Performance + +- **MCP Server**: Handles database operations efficiently with connection pooling +- **Agent**: Uses LangGraph's optimized state management +- **Memory**: Conversation state persisted using LangGraph's MemorySaver +- **Logging**: Rotating file handler (10MB per file, 5 backups) + +## Requirements + +- Python 3.11+ +- Highflame API key +- OpenAI API key (for `openai` route) or Google Gemini API key (for `google` route) +- 100MB disk space for database and logs +- Network access for MCP server communication + +## Support + +For issues or questions: + +1. Check logs in `logs/` directory +2. Run test suite: `python src/main.py test` +3. View database: `streamlit run src/ui/db_viewer.py` +4. 
Review API documentation: `docs/API.md` + +## License + +Apache 2.0 diff --git a/examples/customer_support_agent/docs/API.md b/examples/customer_support_agent/docs/API.md new file mode 100644 index 0000000..bccc68e --- /dev/null +++ b/examples/customer_support_agent/docs/API.md @@ -0,0 +1,460 @@ +# Customer Support Agent API Documentation + +## Overview + +The Customer Support Agent provides a RESTful API for interacting with an AI-powered customer support system. The API supports both stateful conversations (maintaining context) and stateless queries. + +## Base URL + +``` +http://localhost:8000 +``` + +For production deployments, replace `localhost:8000` with your server address. + +## Authentication + +Currently, the API does not require authentication. For production use, implement authentication middleware. + +## Endpoints + +### 1. Health Check + +Check if the API is running and healthy. + +**Endpoint:** `GET /health` + +**Request:** +```bash +curl -X GET http://localhost:8000/health +``` + +**Response:** +```json +{ + "status": "healthy", + "service": "customer-support-agent" +} +``` + +**Status Codes:** +- `200 OK` - Service is healthy + +--- + +### 2. Root Endpoint + +Get API information and available endpoints. + +**Endpoint:** `GET /` + +**Request:** +```bash +curl -X GET http://localhost:8000/ +``` + +**Response:** +```json +{ + "message": "Customer Support Agent API", + "version": "1.0.0", + "endpoints": { + "/chat": "POST - Chat with the agent (maintains conversation state)", + "/generate": "POST - Generate a single response (no state)", + "/health": "GET - Health check" + } +} +``` + +**Status Codes:** +- `200 OK` - Success + +--- + +### 3. Chat Endpoint + +Chat with the agent while maintaining conversation state. The agent remembers previous messages in the conversation thread, allowing for natural multi-turn conversations. 
+ +**Endpoint:** `POST /chat` + +**Request Body:** +```json +{ + "message": "What is the status of my order ORD-001?", + "thread_id": "user-123", + "customer_id": 1 +} +``` + +**Request Fields:** +- `message` (string, required): The user's message or question +- `thread_id` (string, optional): Unique identifier for the conversation thread. Defaults to `"default"` if not provided. Use the same `thread_id` across multiple requests to maintain conversation context. +- `customer_id` (integer, optional): Customer ID if known. Helps the agent access customer-specific information. + +**Example Request:** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "What is the status of my order ORD-001?", + "thread_id": "user-123", + "customer_id": 1 + }' +``` + +**Response:** +```json +{ + "response": "Your order ORD-001 is currently shipped and should arrive within 2-3 business days. You can track it using the tracking number provided in your confirmation email.", + "thread_id": "user-123", + "tool_calls": [ + { + "tool": "lookup_order_tool", + "args": { + "order_number": "ORD-001" + }, + "id": "call_abc123" + } + ], + "intent": "order_inquiry", + "confidence": 0.95 +} +``` + +**Response Fields:** +- `response` (string): The agent's response message +- `thread_id` (string): The conversation thread identifier (same as provided or default) +- `tool_calls` (array, optional): List of tools that were called during processing. 
Each tool call includes: + - `tool` (string): Name of the tool that was called + - `args` (object): Arguments passed to the tool + - `id` (string): Unique identifier for the tool call +- `intent` (string, optional): Detected intent of the user's message (e.g., `order_inquiry`, `technical_support`, `billing`, `general`) +- `confidence` (float, optional): Confidence score of the response (0.0 to 1.0) + +**Status Codes:** +- `200 OK` - Request processed successfully +- `400 Bad Request` - Invalid request body +- `500 Internal Server Error` - Server error processing request + +**Multi-turn Conversation Example:** + +**First Message:** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "I need help with my order", + "thread_id": "user-123" + }' +``` + +**Response:** +```json +{ + "response": "I'd be happy to help you with your order! Could you please provide your order number?", + "thread_id": "user-123", + "intent": "order_inquiry", + "confidence": 0.85 +} +``` + +**Follow-up Message (using same thread_id):** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "My order number is ORD-001", + "thread_id": "user-123" + }' +``` + +**Response:** +```json +{ + "response": "Thank you! I've looked up your order ORD-001. It's currently shipped and should arrive within 2-3 business days. The tracking number is TRACK123456.", + "thread_id": "user-123", + "tool_calls": [ + { + "tool": "lookup_order_tool", + "args": { + "order_number": "ORD-001" + }, + "id": "call_xyz789" + } + ], + "intent": "order_inquiry", + "confidence": 0.95 +} +``` + +--- + +### 4. Generate Endpoint + +Generate a single response without maintaining conversation state. Each request is treated as a new, independent conversation. 
+ +**Endpoint:** `POST /generate` + +**Request Body:** +```json +{ + "message": "What is your return policy?", + "thread_id": "optional-unique-id", + "customer_id": 1 +} +``` + +**Request Fields:** +- `message` (string, required): The user's message or question +- `thread_id` (string, optional): If not provided, a unique ID is generated automatically. Note: Unlike `/chat`, this endpoint does not maintain state between requests. +- `customer_id` (integer, optional): Customer ID if known + +**Example Request:** +```bash +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "message": "What is your return policy?" + }' +``` + +**Response:** +```json +{ + "response": "Our return policy allows returns within 30 days of purchase. Items must be in original condition with tags attached. Electronics can be returned for a full refund or exchange. To initiate a return, log into your account and go to the Orders section.", + "thread_id": "550e8400-e29b-41d4-a716-446655440000" +} +``` + +**Response Fields:** +- `response` (string): The agent's response message +- `thread_id` (string): The thread identifier used for this request (generated if not provided) + +**Status Codes:** +- `200 OK` - Request processed successfully +- `400 Bad Request` - Invalid request body +- `500 Internal Server Error` - Server error processing request + +--- + +## Example Use Cases + +### Order Inquiry + +**Check order status:** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "Where is my order ORD-001?", + "thread_id": "customer-456" + }' +``` + +**Get order history:** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "Show me my order history", + "thread_id": "customer-456", + "customer_id": 5 + }' +``` + +### Customer Information + +**Lookup customer:** +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: 
application/json" \
  -d '{
    "message": "Tell me about customer ID 10",
    "thread_id": "support-agent-1"
  }'
```

**Get customer profile:**
```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "Get the full profile for customer with email john@example.com",
    "thread_id": "support-agent-1"
  }'
```

### Technical Support

**Create support ticket:**
```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "My product is not working. It will not turn on. Please create a support ticket.",
    "thread_id": "customer-789",
    "customer_id": 3
  }'
```

**Update ticket status:**
```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "Update ticket 123 to resolved status",
    "thread_id": "support-agent-2"
  }'
```

### Billing Question

```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "I was charged twice for my order. Can you help?",
    "thread_id": "customer-101",
    "customer_id": 7
  }'
```

### Knowledge Base Search

```bash
curl -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "message": "How do I track my order?"
  }'
```

### Web Search

```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "What is the current weather in New York?",
    "thread_id": "user-weather"
  }'
```

---

## Error Responses

### 400 Bad Request

Invalid request body or missing required fields.

**Response:**
```json
{
  "detail": "Invalid request body"
}
```

**Common Causes:**
- Missing `message` field
- Invalid JSON format
- Wrong data types for fields

### 500 Internal Server Error

Server error while processing the request. 
+ +**Response:** +```json +{ + "detail": "Error processing chat request: [error message]" +} +``` + +**Common Causes:** +- MCP server not running or unreachable +- Database connection issues +- LLM API errors +- Missing environment variables + +--- + +## Response Fields Reference + +### Chat Response + +| Field | Type | Description | +|-------|------|-------------| +| `response` | string | The agent's response message | +| `thread_id` | string | The conversation thread identifier | +| `tool_calls` | array (optional) | List of tools that were called during processing | +| `intent` | string (optional) | Detected intent of the user's message | +| `confidence` | float (optional) | Confidence score of the response (0.0 to 1.0) | + +### Generate Response + +| Field | Type | Description | +|-------|------|-------------| +| `response` | string | The agent's response message | +| `thread_id` | string | The thread identifier used for this request | + +### Tool Call Object + +| Field | Type | Description | +|-------|------|-------------| +| `tool` | string | Name of the tool that was called | +| `args` | object | Arguments passed to the tool | +| `id` | string | Unique identifier for the tool call | + +--- + +## Intent Classification + +The agent automatically classifies user queries into the following intent categories: + +- `order_inquiry` - Questions about orders, shipping, delivery +- `technical_support` - Product issues, troubleshooting +- `billing` - Payment, charges, refunds +- `customer_inquiry` - Customer information requests +- `ticket_management` - Creating or updating support tickets +- `knowledge_base` - FAQ and help article searches +- `general` - General questions or unclear intent + +--- + +## Notes + +1. **Memory/State**: The `/chat` endpoint maintains conversation state using the `thread_id`. Use the same `thread_id` across multiple requests to maintain context. The agent remembers previous messages in the conversation. + +2. 
**Thread IDs**: For `/chat`, use a unique `thread_id` per user/session. For `/generate`, each request is independent and thread IDs are not used for state management. + +3. **Customer ID**: Providing `customer_id` helps the agent access customer-specific information like order history, profile details, and create tickets associated with the customer. + +4. **Tool Calls**: The agent automatically selects and uses appropriate tools based on the user's query. Tool calls are included in the response for transparency and debugging. + +5. **Intent Classification**: The agent automatically classifies user queries to better understand context and route to appropriate tools. + +6. **Confidence Scores**: Higher confidence scores (closer to 1.0) indicate the agent is more certain about its response. Lower scores may indicate ambiguous queries. + +7. **Rate Limiting**: Currently, there is no rate limiting. For production use, implement rate limiting middleware. + +8. **Timeout**: Requests may take several seconds depending on tool calls and LLM response time. Set appropriate timeout values in your HTTP client. + +--- + +## API Versioning + +Current API version: `1.0.0` + +API version information is available at the root endpoint (`GET /`). + +--- + +## Support + +For API issues: + +1. Check server logs in `logs/` directory +2. Verify all services are running (MCP server, API server) +3. Test health endpoint: `GET /health` +4. 
Review environment variable configuration diff --git a/examples/customer_support_agent/requirements.txt b/examples/customer_support_agent/requirements.txt new file mode 100644 index 0000000..943bbea --- /dev/null +++ b/examples/customer_support_agent/requirements.txt @@ -0,0 +1,17 @@ +langgraph>=0.2.0 +langchain-openai>=0.1.0 +langchain-google-genai>=2.0.10 +langchain>=0.2.0 +langchain-core>=0.2.0 +sqlalchemy>=2.0.0 +python-dotenv>=1.0.0 +ddgs>=1.0.0 +fastapi>=0.104.0 +uvicorn>=0.24.0 +pydantic>=2.0.0 +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +streamlit>=1.28.0 +pandas>=2.0.0 +fastmcp>=2.0.0 + diff --git a/examples/customer_support_agent/src/agent/__init__.py b/examples/customer_support_agent/src/agent/__init__.py new file mode 100644 index 0000000..c27ebf0 --- /dev/null +++ b/examples/customer_support_agent/src/agent/__init__.py @@ -0,0 +1 @@ +"""Customer Support Agent using LangGraph.""" diff --git a/examples/customer_support_agent/src/agent/database/__init__.py b/examples/customer_support_agent/src/agent/database/__init__.py new file mode 100644 index 0000000..a7bf4e7 --- /dev/null +++ b/examples/customer_support_agent/src/agent/database/__init__.py @@ -0,0 +1 @@ +"""Database models and utilities.""" diff --git a/examples/customer_support_agent/src/agent/database/models.py b/examples/customer_support_agent/src/agent/database/models.py new file mode 100644 index 0000000..e6d7450 --- /dev/null +++ b/examples/customer_support_agent/src/agent/database/models.py @@ -0,0 +1,132 @@ +"""SQLAlchemy models for the customer support database.""" + +from datetime import datetime +from sqlalchemy import ( + Column, + Integer, + String, + Float, + DateTime, + ForeignKey, + Text, + JSON, + Enum, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import relationship +import enum + +Base = declarative_base() + + +class TicketStatus(enum.Enum): + """Support ticket status enumeration.""" + + OPEN = "open" + IN_PROGRESS = "in_progress" + RESOLVED = "resolved" + CLOSED 
class TicketStatus(enum.Enum):
    """Lifecycle states a support ticket moves through."""

    OPEN = "open"
    IN_PROGRESS = "in_progress"
    RESOLVED = "resolved"
    CLOSED = "closed"


class TicketPriority(enum.Enum):
    """Urgency levels assignable to a support ticket."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    URGENT = "urgent"


class OrderStatus(enum.Enum):
    """States an order can be in, from placement through return."""

    PENDING = "pending"
    PROCESSING = "processing"
    SHIPPED = "shipped"
    DELIVERED = "delivered"
    CANCELLED = "cancelled"
    RETURNED = "returned"


class Customer(Base):
    """A customer record; hub for orders, tickets and conversations."""

    __tablename__ = "customers"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    # Email doubles as the natural lookup key, hence unique + indexed.
    email = Column(String(255), unique=True, nullable=False, index=True)
    phone = Column(String(50), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # One-to-many collections, back-populated from the child models.
    orders = relationship("Order", back_populates="customer")
    tickets = relationship("Ticket", back_populates="customer")
    conversations = relationship("Conversation", back_populates="customer")


class Order(Base):
    """A single customer order."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True, index=True)
    customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False)
    # Human-facing identifier (e.g. "ORD-001"); unique lookup key.
    order_number = Column(String(100), unique=True, nullable=False, index=True)
    status = Column(Enum(OrderStatus), default=OrderStatus.PENDING, nullable=False)
    total = Column(Float, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)

    customer = relationship("Customer", back_populates="orders")
    tickets = relationship("Ticket", back_populates="order")


class Ticket(Base):
    """A support ticket, optionally linked to an order."""

    __tablename__ = "tickets"

    id = Column(Integer, primary_key=True, index=True)
    customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False)
    # Nullable: general questions are not tied to any order.
    order_id = Column(Integer, ForeignKey("orders.id"), nullable=True)
    subject = Column(String(255), nullable=False)
    description = Column(Text, nullable=False)
    status = Column(Enum(TicketStatus), default=TicketStatus.OPEN, nullable=False)
    priority = Column(
        Enum(TicketPriority), default=TicketPriority.MEDIUM, nullable=False
    )
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    customer = relationship("Customer", back_populates="tickets")
    order = relationship("Order", back_populates="tickets")


class KnowledgeBase(Base):
    """A help-center / FAQ article searchable by the agent."""

    __tablename__ = "knowledge_base"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(255), nullable=False)
    content = Column(Text, nullable=False)
    category = Column(String(100), nullable=True)
    tags = Column(String(500), nullable=True)  # Comma-separated tags
    created_at = Column(DateTime, default=datetime.utcnow)


class Conversation(Base):
    """Persisted chat transcript, optionally attributed to a customer."""

    __tablename__ = "conversations"

    id = Column(Integer, primary_key=True, index=True)
    customer_id = Column(Integer, ForeignKey("customers.id"), nullable=True)
    messages = Column(JSON, nullable=False)  # List of message dicts
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    customer = relationship("Customer", back_populates="conversations")
def get_customer_by_email(db: Session, email: str) -> Optional[Customer]:
    """Return the customer with this exact email, or None."""
    return db.query(Customer).filter(Customer.email == email).first()


def get_customer_by_id(db: Session, customer_id: int) -> Optional[Customer]:
    """Return the customer with this primary key, or None."""
    return db.query(Customer).filter(Customer.id == customer_id).first()


def get_customer_by_phone(db: Session, phone: str) -> Optional[Customer]:
    """Return the customer with this exact phone number, or None."""
    return db.query(Customer).filter(Customer.phone == phone).first()


def _persist(db: Session, obj):
    """Add *obj* to the session, commit, and refresh DB-generated values."""
    db.add(obj)
    db.commit()
    db.refresh(obj)
    return obj


def create_customer(
    db: Session, name: str, email: str, phone: Optional[str] = None
) -> Customer:
    """Insert and return a new customer row."""
    return _persist(db, Customer(name=name, email=email, phone=phone))


def get_order_by_number(db: Session, order_number: str) -> Optional[Order]:
    """Return the order with this order number, or None."""
    return db.query(Order).filter(Order.order_number == order_number).first()


def get_orders_by_customer(db: Session, customer_id: int) -> List[Order]:
    """Return every order belonging to the given customer."""
    return db.query(Order).filter(Order.customer_id == customer_id).all()


def create_order(
    db: Session,
    customer_id: int,
    order_number: str,
    total: float,
    status: OrderStatus = OrderStatus.PENDING,
) -> Order:
    """Insert and return a new order row (defaults to PENDING)."""
    new_order = Order(
        customer_id=customer_id,
        order_number=order_number,
        total=total,
        status=status,
    )
    return _persist(db, new_order)


def get_ticket_by_id(db: Session, ticket_id: int) -> Optional[Ticket]:
    """Return the ticket with this primary key, or None."""
    return db.query(Ticket).filter(Ticket.id == ticket_id).first()


def create_ticket(
    db: Session,
    customer_id: int,
    subject: str,
    description: str,
    priority: TicketPriority = TicketPriority.MEDIUM,
    order_id: Optional[int] = None,
) -> Ticket:
    """Insert and return a new support ticket."""
    new_ticket = Ticket(
        customer_id=customer_id,
        order_id=order_id,
        subject=subject,
        description=description,
        priority=priority,
    )
    return _persist(db, new_ticket)


def update_ticket(
    db: Session,
    ticket_id: int,
    status: Optional[TicketStatus] = None,
    notes: Optional[str] = None,
) -> Optional[Ticket]:
    """Apply a status change and/or append a note to a ticket.

    Returns the refreshed ticket, or None when the id is unknown.
    """
    ticket = get_ticket_by_id(db, ticket_id)
    if ticket is None:
        return None

    if status:
        ticket.status = status
    if notes:
        # Notes are appended to the free-text description rather than
        # being kept in a separate table.
        ticket.description += f"\n\n[Update]: {notes}"

    from datetime import datetime

    ticket.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(ticket)
    return ticket


def search_knowledge_base(
    db: Session,
    query: str,
    category: Optional[str] = None,
    limit: int = 5,
) -> List[KnowledgeBase]:
    """Keyword search over article title, content and tags (ILIKE match)."""
    conditions = or_(
        KnowledgeBase.title.ilike(f"%{query}%"),
        KnowledgeBase.content.ilike(f"%{query}%"),
        KnowledgeBase.tags.ilike(f"%{query}%"),
    )
    if category:
        # Narrow the keyword match to a single category when requested.
        conditions = conditions & (KnowledgeBase.category == category)

    return db.query(KnowledgeBase).filter(conditions).limit(limit).all()


def get_knowledge_base_by_category(db: Session, category: str) -> List[KnowledgeBase]:
    """Return every article filed under the given category."""
    return db.query(KnowledgeBase).filter(KnowledgeBase.category == category).all()


def create_conversation(
    db: Session, messages: List[dict], customer_id: Optional[int] = None
) -> Conversation:
    """Insert and return a new conversation transcript."""
    return _persist(db, Conversation(customer_id=customer_id, messages=messages))


def update_conversation(
    db: Session,
    conversation_id: int,
    messages: List[dict],
) -> Optional[Conversation]:
    """Replace a conversation's message list and bump its timestamp.

    Returns the refreshed conversation, or None when the id is unknown.
    """
    conversation = (
        db.query(Conversation)
        .filter(Conversation.id == conversation_id)
        .first()
    )
    if conversation is None:
        return None

    conversation.messages = messages
    from datetime import datetime

    conversation.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(conversation)
    return conversation
def get_database_path() -> str:
    """Resolve the SQLite file path (DATABASE_PATH env var, or default)."""
    path = os.getenv("DATABASE_PATH", "./src/db/support_agent.db")
    logger.debug(f"Database path: {path}")
    return path


def get_engine():
    """Build a SQLite engine, creating the parent directory if needed."""
    database_path = get_database_path()
    # abspath() makes dirname() non-empty; `or "."` is a defensive fallback.
    os.makedirs(os.path.dirname(os.path.abspath(database_path)) or ".", exist_ok=True)
    logger.debug(f"Creating database engine for: {database_path}")
    return create_engine(f"sqlite:///{database_path}", echo=False)


def get_session():
    """Open a fresh Session bound to a newly created engine."""
    engine = get_engine()
    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return session_factory()


def init_database(seed_data: bool = True):
    """Create all tables and, unless disabled, seed sample data."""
    logger.info("Initializing database")
    engine = get_engine()

    logger.debug("Creating database tables")
    Base.metadata.create_all(bind=engine)
    logger.info("Database tables created successfully")

    if seed_data:
        logger.debug("Seeding database with sample data")
        seed_database()
def seed_database():
    """Populate the database with deterministic sample data.

    Idempotent: if any customer rows already exist the seed is skipped.
    Rolls back and re-raises on any failure; always closes the session.
    """
    db = get_session()

    try:
        # Idempotence guard: never seed on top of existing data.
        customer_count = db.query(Customer).count()
        if customer_count > 0:
            logger.info(
                f"Database already contains {customer_count} customers. "
                "Skipping seed."
            )
            return

        logger.info("Seeding database with sample data")

        # --- Customers -----------------------------------------------
        customers_data = [
            ("John Doe", "john.doe@example.com", "+1-555-0101"),
            ("Jane Smith", "jane.smith@example.com", "+1-555-0102"),
            ("Bob Johnson", "bob.johnson@example.com", "+1-555-0103"),
            ("Alice Williams", "alice.williams@example.com", "+1-555-0104"),
            ("Charlie Brown", "charlie.brown@example.com", "+1-555-0105"),
            ("Diana Prince", "diana.prince@example.com", "+1-555-0106"),
            ("Edward Norton", "edward.norton@example.com", "+1-555-0107"),
            ("Fiona Apple", "fiona.apple@example.com", "+1-555-0108"),
            ("George Lucas", "george.lucas@example.com", "+1-555-0109"),
            ("Helen Mirren", "helen.mirren@example.com", "+1-555-0110"),
        ]
        customers = [
            create_customer(db, name, email, phone)
            for name, email, phone in customers_data
        ]
        customer1, customer2, customer3 = customers[:3]

        # --- Orders --------------------------------------------------
        orders_data = [
            (customer1.id, "ORD-001", 99.99, OrderStatus.SHIPPED),
            (customer1.id, "ORD-002", 149.50, OrderStatus.PROCESSING),
            (customer1.id, "ORD-003", 79.99, OrderStatus.DELIVERED),
            (customer2.id, "ORD-004", 199.99, OrderStatus.SHIPPED),
            (customer2.id, "ORD-005", 299.50, OrderStatus.PROCESSING),
            (customer3.id, "ORD-006", 49.99, OrderStatus.DELIVERED),
            (customer3.id, "ORD-007", 89.99, OrderStatus.CANCELLED),
            (customers[3].id, "ORD-008", 159.99, OrderStatus.SHIPPED),
            (customers[4].id, "ORD-009", 249.99, OrderStatus.PROCESSING),
            (customers[5].id, "ORD-010", 179.99, OrderStatus.DELIVERED),
            (customers[6].id, "ORD-011", 129.99, OrderStatus.SHIPPED),
            (customers[7].id, "ORD-012", 349.99, OrderStatus.PROCESSING),
            (customers[8].id, "ORD-013", 99.99, OrderStatus.RETURNED),
            (customers[9].id, "ORD-014", 219.99, OrderStatus.DELIVERED),
        ]
        orders = [
            create_order(db, customer_id, order_num, total, status)
            for customer_id, order_num, total, status in orders_data
        ]
        order1 = orders[0]

        # --- Tickets -------------------------------------------------
        tickets_data = [
            (
                customer1.id,
                "Order not received",
                "I placed order ORD-001 two weeks ago but haven't received it yet.",
                TicketPriority.HIGH,
                order1.id,
            ),
            (
                customer2.id,
                "Product question",
                "What is the return policy for electronics?",
                TicketPriority.MEDIUM,
                None,
            ),
            (
                customer3.id,
                "Billing issue",
                "I was charged twice for order ORD-006. Please refund one charge.",
                TicketPriority.URGENT,
                orders[5].id,
            ),
            (
                customers[3].id,
                "Technical support",
                "The product I received is not working. It won't turn on.",
                TicketPriority.HIGH,
                orders[7].id,
            ),
            (
                customers[4].id,
                "Shipping delay",
                (
                    "My order ORD-009 was supposed to arrive yesterday but "
                    "hasn't shown up."
                ),
                TicketPriority.MEDIUM,
                orders[8].id,
            ),
            (
                customers[5].id,
                "Return request",
                "I want to return order ORD-010. The item doesn't fit.",
                TicketPriority.LOW,
                orders[9].id,
            ),
            (
                customers[6].id,
                "Account question",
                "How do I change my email address?",
                TicketPriority.LOW,
                None,
            ),
            (
                customers[7].id,
                "Payment failed",
                "My payment for order ORD-012 failed. Can you help?",
                TicketPriority.HIGH,
                orders[11].id,
            ),
            (
                customers[8].id,
                "Refund status",
                "I returned order ORD-013 last week. When will I get my refund?",
                TicketPriority.MEDIUM,
                orders[12].id,
            ),
            (
                customers[9].id,
                "Product inquiry",
                "Do you have this product in a different color?",
                TicketPriority.LOW,
                None,
            ),
        ]
        tickets = [
            create_ticket(db, customer_id, subject, desc, priority, order_id)
            for customer_id, subject, desc, priority, order_id in tickets_data
        ]

        # --- Knowledge base ------------------------------------------
        kb_articles = [
            KnowledgeBase(
                title="Return Policy",
                content=(
                    "Our return policy allows returns within 30 days of purchase. "
                    "Items must be in original condition with tags attached. "
                    "Electronics can be returned for a full refund or exchange. "
                    "To initiate a return, log into your account and go to the "
                    "Orders section, then click 'Return' next to the item you "
                    "wish to return."
                ),
                category="policies",
                tags="returns,refund,policy,exchange",
            ),
            KnowledgeBase(
                title="Shipping Information",
                content=(
                    "Standard shipping takes 5-7 business days. Express shipping "
                    "(2-3 days) and overnight shipping are available at checkout. "
                    "You will receive a tracking number via email once your order "
                    "ships. International shipping is available to most countries "
                    "with delivery times of 10-21 business days."
                ),
                category="shipping",
                tags="shipping,delivery,tracking,international",
            ),
            KnowledgeBase(
                title="Account Management",
                content=(
                    "You can manage your account by logging in at our website. "
                    "From there you can view order history, update your address, "
                    "change your password, and manage payment methods. To update "
                    "your email, go to Account Settings > Profile Information."
                ),
                category="account",
                tags="account,login,profile,settings",
            ),
            KnowledgeBase(
                title="Order Tracking",
                content=(
                    "To track your order, use the order number provided in your "
                    "confirmation email. You can also log into your account to "
                    "see all order statuses and tracking information. If your "
                    "order shows as 'Shipped', click on the tracking number to "
                    "see real-time updates from the carrier."
                ),
                category="orders",
                tags="tracking,order,status,delivery",
            ),
            KnowledgeBase(
                title="Payment Methods",
                content=(
                    "We accept all major credit cards (Visa, MasterCard, "
                    "American Express), PayPal, and Apple Pay. All payments are "
                    "processed securely. We do not store your full credit card "
                    "information. For security, we use SSL encryption for all "
                    "transactions."
                ),
                category="billing",
                tags="payment,credit card,paypal,security",
            ),
            KnowledgeBase(
                title="Technical Support",
                content=(
                    "For technical issues with products, please check the "
                    "troubleshooting guide in the product manual. If the issue "
                    "persists, contact our technical support team with your order "
                    "number and a description of the problem. Our support team "
                    "is available Monday-Friday, 9 AM - 6 PM EST."
                ),
                category="support",
                tags="technical,troubleshooting,help,contact",
            ),
            KnowledgeBase(
                title="Refund Processing",
                content=(
                    "Refunds are typically processed within 5-7 business days "
                    "after we receive your returned item. The refund will be "
                    "issued to the original payment method. You will receive an "
                    "email confirmation once the refund has been processed. For "
                    "credit cards, it may take an additional 3-5 business days "
                    "to appear on your statement."
                ),
                category="billing",
                tags="refund,return,payment,processing",
            ),
            KnowledgeBase(
                title="Order Cancellation",
                content=(
                    "You can cancel an order within 24 hours of placing it if "
                    "it hasn't shipped yet. To cancel, go to your Orders page "
                    "and click 'Cancel Order'. If your order has already shipped, "
                    "you'll need to return it using our standard return process. "
                    "Cancelled orders are refunded immediately."
                ),
                category="orders",
                tags="cancel,order,refund,policy",
            ),
            KnowledgeBase(
                title="Product Warranty",
                content=(
                    "All products come with a manufacturer's warranty. "
                    "Electronics typically have a 1-year warranty covering "
                    "defects in materials and workmanship. Warranty claims must "
                    "be made within the warranty period and require proof of "
                    "purchase. Contact support with your order number to "
                    "initiate a warranty claim."
                ),
                category="support",
                tags="warranty,electronics,defect,claim",
            ),
            KnowledgeBase(
                title="Gift Cards and Promo Codes",
                content=(
                    "Gift cards can be applied at checkout by entering the code "
                    "in the 'Promo Code' field. Gift cards never expire and can "
                    "be used for any purchase. Promo codes are case-sensitive "
                    "and may have expiration dates or minimum purchase "
                    "requirements. Only one promo code can be used per order."
                ),
                category="billing",
                tags="gift card,promo code,discount,coupon",
            ),
            KnowledgeBase(
                title="Shipping Address Changes",
                content=(
                    "You can change your shipping address for an order if it "
                    "hasn't shipped yet. Log into your account, go to Orders, "
                    "and click 'Change Address' next to the order. Once an "
                    "order has shipped, address changes must be made directly "
                    "with the carrier using your tracking number."
                ),
                category="orders",
                tags="address,shipping,change,update",
            ),
            KnowledgeBase(
                title="Damaged or Defective Items",
                content=(
                    "If you receive a damaged or defective item, please contact "
                    "us within 48 hours of delivery. We'll arrange for a "
                    "replacement or full refund at no cost to you. Please "
                    "include photos of the damage when contacting support. We "
                    "may request that you return the item for inspection."
                ),
                category="support",
                tags="damaged,defective,replacement,refund",
            ),
        ]
        db.add_all(kb_articles)

        db.commit()
        logger.info(
            f"Database seeded successfully: {len(customers)} customers, "
            f"{len(orders)} orders, {len(tickets)} tickets, "
            f"{len(kb_articles)} KB articles"
        )

    except Exception as e:
        # Keep the DB consistent on partial failure, then surface the error.
        db.rollback()
        logger.error(f"Error seeding database: {e}", exc_info=True)
        raise
    finally:
        db.close()


if __name__ == "__main__":
    # Initialize (and seed) the database when run directly.
    init_database(seed_data=True)
def get_tools():
    """Assemble the agent's toolbox: MCP database tools plus direct tools."""
    mcp_tools = create_mcp_tools()
    direct_tools = [
        web_search_tool,
        web_search_news_tool,
        send_email_tool,
    ]
    return mcp_tools + direct_tools


def start_node(state: AgentState) -> AgentState:
    """Seed conversation with a single system prompt and ensure defaults."""
    messages = state.get("messages", [])
    outputs: dict = {}

    # Inject the system prompt exactly once per conversation.
    if not any(isinstance(msg, SystemMessage) for msg in messages):
        system_prompt = SystemMessage(
            content=(
                "You are a professional customer support AI assistant for "
                "Highflame. Be direct and helpful - only greet on the very "
                "first message of a conversation. "
                "\n\nYour capabilities include:"
                "\n1. Order Management: Look up order details, check order "
                "status, view order history"
                "\n2. Customer Support: Find customer information, view "
                "customer profiles, create new customers"
                "\n3. Ticket Management: Create, update, and check support "
                "tickets"
                "\n4. Knowledge Base: Search help articles and FAQs"
                "\n5. Web Search: Find current information from the web"
                "\n6. Email: Send emails to customers"
                "\n\nCRITICAL RULES:"
                "\n1. You must ALWAYS use tools to retrieve factual "
                "information. Never guess or fabricate data."
                "\n2. MEMORY AND CONVERSATION HISTORY: You have access to "
                "the FULL conversation history in the messages you receive. "
                "When a user asks about previous questions, messages, or "
                "conversation context, you MUST look through the "
                "HumanMessage instances in the conversation history and "
                "answer based on what was actually said. DO NOT default to "
                "listing capabilities when asked about conversation history."
                "\n3. When asked 'what questions did I ask', 'what did I ask "
                "you', 'last questions', or similar queries about the "
                "conversation, you MUST examine the HumanMessage objects in "
                "the conversation history and list the actual questions the "
                "user asked."
                "\n4. When asked about your capabilities (e.g., 'what can you "
                "do'), explain what you can do based on the tools available "
                "to you."
                "\n5. After receiving tool results, provide a clear, "
                "concise summary to the user."
                "\n6. Always use the conversation history to provide "
                "context-aware responses. The messages array contains the "
                "full conversation."
            )
        )
        outputs["messages"] = [system_prompt]

    # Initialize optional state keys so downstream nodes can read them safely.
    defaults = {
        "customer_id": None,
        "current_ticket_id": None,
        "tool_calls": [],
        "context": {},
        "intent": None,
        "confidence": 0.5,
        "should_escalate": False,
    }
    for key, value in defaults.items():
        if key not in state:
            outputs[key] = value

    return outputs
+ ) + ) + outputs["messages"] = [system_prompt] + + # Ensure optional fields are initialized so downstream nodes can rely on them + if "customer_id" not in state: + outputs["customer_id"] = None + if "current_ticket_id" not in state: + outputs["current_ticket_id"] = None + if "tool_calls" not in state: + outputs["tool_calls"] = [] + if "context" not in state: + outputs["context"] = {} + if "intent" not in state: + outputs["intent"] = None + if "confidence" not in state: + outputs["confidence"] = 0.5 + if "should_escalate" not in state: + outputs["should_escalate"] = False + + return outputs + + +def understand_intent_node(state: AgentState) -> AgentState: + """Classify customer query intent with improved categorization.""" + logger.debug("Understanding intent node") + llm = get_llm() + messages = state.get("messages", []) + + # Get the last user message + last_message = "" + for msg in reversed(messages): + if isinstance(msg, HumanMessage): + last_message = msg.content + break + + logger.debug( + f"Classifying intent for message: {last_message[:100]}" + ) + intent_prompt = f"""Analyze this customer support message and classify its primary intent. # noqa: E501 + +Message: "{last_message}" + +Intent categories: +- order_inquiry: Questions about specific orders, shipping status, tracking, delivery +- account: Customer account info, profile, login, registration +- ticket_management: Creating, updating, or checking support tickets +- billing: Payments, refunds, invoices, pricing questions +- technical_support: Product issues, troubleshooting, how-to questions +- general: General questions, policies, FAQs, product information + +Rules: +1. Choose the MOST SPECIFIC category that fits +2. If message contains order numbers (ORD-XXX), likely "order_inquiry" +3. If asking about policies/procedures, likely "general" +4. 
def call_tools_node(state: AgentState) -> AgentState:
    """Let the tool-bound LLM decide which tools (if any) to invoke."""
    logger.debug("Call tools node")
    llm = get_llm()
    tools = get_tools()
    llm_with_tools = llm.bind_tools(tools)

    messages = state.get("messages", [])
    if not messages:
        logger.warning("No messages in state for call_tools_node")
        return {}

    # Log conversation context for debugging memory behavior.
    human_messages = [
        msg.content for msg in messages if isinstance(msg, HumanMessage)
    ]
    logger.debug(
        f"Calling tools with {len(messages)} total messages, "
        f"{len(human_messages)} human messages"
    )
    if len(human_messages) > 1:
        logger.debug(f"Previous human messages: {human_messages[:-1]}")

    logger.debug(f"Invoking LLM with {len(tools)} tools available")
    response = llm_with_tools.invoke(messages)

    requested = getattr(response, "tool_calls", None)
    if requested:
        # Append a lightweight record of each requested call to the state.
        recorded = state.get("tool_calls", []) + [
            {
                "tool": call["name"],
                "args": call.get("args", {}),
                "id": call.get("id", ""),
            }
            for call in requested
        ]
        tool_names = [call["name"] for call in requested]
        logger.info(
            f"LLM requested {len(requested)} tool calls: "
            f"{tool_names}"
        )
        return {
            "messages": [response],
            "tool_calls": recorded,
        }

    # No tool calls made - pass the direct answer through.
    logger.debug("LLM responded without tool calls")
    if getattr(response, "content", None):
        logger.debug(f"LLM response preview: {response.content[:200]}")
    return {"messages": [response]}
def synthesize_response_node(state: AgentState) -> AgentState:
    """Generate the final user-facing response from the accumulated messages.

    Feeds the full history (system prompt, user turns, AI turns and tool
    results) to the LLM so the answer is context-aware, then lowers the
    ``confidence`` score when the reply sounds uncertain.
    """
    logger.debug("Synthesize response node")
    llm = get_llm()
    messages = state.get("messages", [])
    if not messages:
        logger.warning("No messages in state for synthesize_response_node")
        return {}

    # Keep ALL messages (including the system message) for full context.
    llm_messages = [msg for msg in messages if hasattr(msg, "content")]

    human_count = sum(1 for msg in llm_messages if isinstance(msg, HumanMessage))
    ai_count = sum(1 for msg in llm_messages if isinstance(msg, AIMessage))
    tool_count = sum(1 for msg in llm_messages if isinstance(msg, ToolMessage))
    logger.debug(
        f"Synthesizing response from {len(llm_messages)} messages "
        f"(Human: {human_count}, AI: {ai_count}, Tool: {tool_count})"
    )

    # Log all human messages to make memory problems debuggable.
    all_human = [
        msg.content for msg in llm_messages if isinstance(msg, HumanMessage)
    ]
    if all_human:
        logger.debug(
            f"All human messages in context ({len(all_human)} total): "
            f"{all_human}"
        )
        if len(all_human) > 1:
            logger.info(
                f"Conversation history available: {len(all_human)} human "
                "messages in context"
            )
            logger.debug(f"Previous questions: {all_human[:-1]}")

    response = llm.invoke(llm_messages)
    resp_len = len(response.content) if hasattr(response, "content") else 0
    logger.debug(f"Generated response length: {resp_len}")

    confidence = state.get("confidence", 0.5)
    if hasattr(response, "content") and response.content:
        content_lower = response.content.lower()
        # BUG FIX: "I don't know" was previously matched case-sensitively
        # against the raw content, so variants like "i don't know" slipped
        # through. Both phrases are now checked on the lowered text.
        if "i don't know" in content_lower or "i'm not sure" in content_lower:
            confidence = min(confidence, 0.4)
            logger.debug(
                f"Low confidence detected, adjusted to {confidence}"
            )

    return {
        "messages": [response],
        "confidence": confidence,
    }


def escalate_node(state: AgentState) -> AgentState:
    """Emit a hand-off message and flag the conversation for a human agent."""
    escalation_message = AIMessage(
        content=(
            "I understand this is a complex issue. Let me connect you with "
            "a human support agent who can provide more specialized "
            "assistance. Your ticket has been created and you'll receive "
            "an email confirmation shortly."
        )
    )
    return {"messages": [escalation_message], "should_escalate": True}
def should_continue_after_call_tools(  # noqa: C901
    state: AgentState,
) -> Literal["tools", "synthesize", "escalate", "end"]:
    """Route after call_tools: run tools, synthesize, escalate, or finish."""
    messages = state.get("messages", [])
    if not messages:
        logger.debug("No messages, ending")
        return "end"

    last_message = messages[-1]

    # Explicit escalation flag wins over everything else.
    if state.get("should_escalate", False):
        logger.info("Escalation requested")
        return "escalate"

    # Pending tool calls go to the tools node for execution.
    if getattr(last_message, "tool_calls", None):
        num_calls = len(last_message.tool_calls)
        logger.debug(f"Routing to tools node ({num_calls} tool calls)")
        return "tools"

    # Very low confidence means a human should take over.
    confidence = state.get("confidence", 0.5)
    if confidence < 0.3:
        logger.warning(f"Low confidence ({confidence}), escalating")
        return "escalate"

    if isinstance(last_message, AIMessage) and last_message.content:
        # The model answered directly. Synthesis is only needed when some
        # earlier tool results were never folded into a plain AI reply.

        # Locate the most recent AI message that requested tool calls.
        tool_call_ai_idx = -1
        for i in range(len(messages) - 1, -1, -1):
            candidate = messages[i]
            if isinstance(candidate, AIMessage) and getattr(
                candidate, "tool_calls", None
            ):
                tool_call_ai_idx = i
                break

        pending_tool_results = False
        if tool_call_ai_idx >= 0:
            # Tool results that followed that tool-calling AI message.
            tool_indices = [
                i
                for i in range(tool_call_ai_idx + 1, len(messages))
                if isinstance(messages[i], ToolMessage)
            ]
            if tool_indices:
                # A plain (non-tool-calling) AI message after the last tool
                # result means those results were already synthesized.
                synthesized = any(
                    isinstance(messages[i], AIMessage)
                    and not getattr(messages[i], "tool_calls", None)
                    for i in range(tool_indices[-1] + 1, len(messages))
                )
                if not synthesized:
                    pending_tool_results = True
                    logger.debug(
                        f"Found unsynthesized tool messages after index "
                        f"{tool_call_ai_idx}"
                    )

        if not pending_tool_results:
            # LLM answered directly and nothing is left to incorporate.
            logger.info("Model provided direct answer without tools, ending")
            return "end"

    logger.debug("Routing to synthesize node to incorporate tool results")
    return "synthesize"
direct answer without tools, ending") + return "end" + + # If we have tool results, synthesize them + logger.debug("Routing to synthesize node to incorporate tool results") + return "synthesize" + + +def should_continue_after_synthesize( + state: AgentState, +) -> Literal["call_tools", "escalate", "end"]: + """Determine next step after synthesize node.""" + messages = state.get("messages", []) + if not messages: + return "end" + + last_message = messages[-1] + + # Check if we need to escalate + if state.get("should_escalate", False): + return "escalate" + + # Check if last message has tool calls - need to call tools again + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "call_tools" + + # Check confidence for escalation + confidence = state.get("confidence", 0.5) + if confidence < 0.3: + return "escalate" + + # Default to ending conversation + return "end" + + +def create_agent_graph(): + """Create and compile the LangGraph agent.""" + logger.info("Creating agent graph") + # Create tool node + tools = get_tools() + tool_node = ToolNode(tools) + logger.debug("Tool node created") + + # Create graph + workflow = StateGraph(AgentState) + + # Add nodes + workflow.add_node("start", start_node) + workflow.add_node("understand_intent", understand_intent_node) + workflow.add_node("call_tools", call_tools_node) + workflow.add_node("tools", tool_node) + workflow.add_node("synthesize", synthesize_response_node) + workflow.add_node("escalate", escalate_node) + + # Set entry point + workflow.set_entry_point("start") + + # Add edges + workflow.add_edge("start", "understand_intent") + workflow.add_edge("understand_intent", "call_tools") + workflow.add_conditional_edges( + "call_tools", + should_continue_after_call_tools, + { + "tools": "tools", + "synthesize": "synthesize", + "escalate": "escalate", + "end": END, + }, + ) + # Tools node outputs tool messages, which should go directly to synthesize + # But we need to ensure the message sequence is 
valid + workflow.add_edge("tools", "synthesize") + workflow.add_conditional_edges( + "synthesize", + should_continue_after_synthesize, + { + "call_tools": "call_tools", + "escalate": "escalate", + "end": END, + }, + ) + workflow.add_edge("escalate", END) + + # Add memory for conversation persistence + memory = MemorySaver() + + # Compile graph + logger.debug("Compiling agent graph with memory") + app = workflow.compile(checkpointer=memory) + logger.info("Agent graph compiled successfully") + + return app + + +# Create the agent instance +def get_agent(): + """Get the compiled agent graph.""" + return create_agent_graph() diff --git a/examples/customer_support_agent/src/agent/llm.py b/examples/customer_support_agent/src/agent/llm.py new file mode 100644 index 0000000..f7caa3a --- /dev/null +++ b/examples/customer_support_agent/src/agent/llm.py @@ -0,0 +1,72 @@ +"""Highflame LLM integration (unified provider for OpenAI and Google/Gemini).""" + +import os +from langchain_openai import ChatOpenAI +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +def get_llm(temperature: float = 0.8): + """ + Get configured LLM via Highflame. + + Args: + provider: Ignored - always uses Highflame. Kept for API compatibility. 
+ temperature: Temperature for sampling (default: 0.8) + + Returns: + Configured LLM instance via Highflame + + Raises: + ValueError: If API key not set + """ + logger.debug(f"Getting LLM instance - temperature: {temperature}") + + # Highflame integration (always uses Highflame) + highflame_api_key = os.getenv("HIGHFLAME_API_KEY") + if not highflame_api_key: + logger.error("HIGHFLAME_API_KEY not set in environment variables") + raise ValueError("HIGHFLAME_API_KEY not set in environment variables") + + # Read route and model from env vars + route = os.getenv("HIGHFLAME_ROUTE", "").strip() + model = os.getenv("MODEL", "").strip() + llm_api_key = os.getenv("LLM_API_KEY", "").strip() + + # Route is required + if not route: + logger.error("HIGHFLAME_ROUTE not set in environment variables") + raise ValueError( + "HIGHFLAME_ROUTE must be set (e.g., 'openai' or 'google')" + ) + + # Model is required + if not model: + logger.error("MODEL not set in environment variables") + raise ValueError( + "MODEL must be set (e.g., 'gpt-4o-mini' or " + "'gemini-2.5-flash-lite')" + ) + + # LLM API key is required + if not llm_api_key: + logger.error("LLM_API_KEY not set in environment variables") + raise ValueError( + "LLM_API_KEY must be set (OpenAI API key for 'openai' route, " + "Gemini API key for 'google' route)" + ) + + logger.info(f"Initializing Highflame LLM - route: {route}, model: {model}") + + # Highflame provides a unified OpenAI-compatible API for both OpenAI and Gemini + return ChatOpenAI( + model=model, + base_url="https://api.highflame.app/v1", + api_key=llm_api_key, + default_headers={ + "X-Javelin-apikey": highflame_api_key, + "X-Javelin-route": route, + }, + temperature=temperature, + ) diff --git a/examples/customer_support_agent/src/agent/mcp_tools.py b/examples/customer_support_agent/src/agent/mcp_tools.py new file mode 100644 index 0000000..eb94794 --- /dev/null +++ b/examples/customer_support_agent/src/agent/mcp_tools.py @@ -0,0 +1,289 @@ +"""MCP client wrapper 
for database tools.""" + +import os +import asyncio +from fastmcp import Client +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, Field +from typing import Optional, List +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +# MCP client singleton +_mcp_client = None + + +def get_mcp_client(): + """Get or create MCP client.""" + global _mcp_client + if _mcp_client is None: + # FastMCP defaults to port 8000 with SSE transport + mcp_url = os.getenv("MCP_SERVER_URL", "http://0.0.0.0:9000/mcp") + _mcp_client = Client(mcp_url) + return _mcp_client + + +async def call_mcp_tool_async(tool_name: str, **kwargs): + """Call MCP tool asynchronously.""" + logger.debug(f"Calling MCP tool: {tool_name} with args: {kwargs}") + client = get_mcp_client() + try: + async with client: + result = await client.call_tool(tool_name, kwargs) + if not result: + logger.error(f"MCP tool {tool_name} returned None result") + return "Error: Tool returned no result" + + if not result.content: + logger.warning(f"MCP tool {tool_name} returned empty content") + return "" + + if len(result.content) == 0: + logger.warning( + f"MCP tool {tool_name} returned empty content list" + ) + return "" + + first_content = result.content[0] + if not hasattr(first_content, 'text'): + logger.error( + f"MCP tool {tool_name} content missing text attribute" + ) + return ( + f"Error: Unexpected response format from tool {tool_name}" + ) + + response_text = first_content.text if first_content.text else "" + logger.debug( + f"MCP tool {tool_name} returned response " + f"(length: {len(response_text)})" + ) + return response_text + except Exception as e: + logger.error(f"Error calling MCP tool {tool_name}: {e}", exc_info=True) + raise + + +def call_mcp_tool(tool_name: str, **kwargs): + """Call MCP tool synchronously.""" + return asyncio.run(call_mcp_tool_async(tool_name, **kwargs)) + + +# Pydantic Schemas for Tools + + +class SearchKnowledgeBaseSchema(BaseModel): + 
"""Search knowledge base input schema.""" + + query: str = Field( + ..., description="The search query or keywords to search for" + ) + category: Optional[str] = Field( + None, description="Optional category to filter results" + ) + + +class GetKnowledgeBaseByCategorySchema(BaseModel): + """Get knowledge base by category input schema.""" + + category: str = Field(..., description="The category to retrieve articles from") + + +class LookupOrderSchema(BaseModel): + """Lookup order input schema.""" + + order_number: str = Field( + ..., description="The order number to look up (e.g., 'ORD-001')" + ) + customer_id: Optional[int] = Field( + None, description="Optional customer ID to verify" + ) + + +class GetOrderStatusSchema(BaseModel): + """Get order status input schema.""" + + order_number: str = Field(..., description="The order number to check") + + +class GetOrderHistorySchema(BaseModel): + """Get order history input schema.""" + + customer_id: int = Field(..., description="The ID of the customer") + + +class LookupCustomerSchema(BaseModel): + """Lookup customer input schema.""" + + email: Optional[str] = Field(None, description="Customer email address") + phone: Optional[str] = Field(None, description="Customer phone number") + customer_id: Optional[int] = Field(None, description="Customer ID") + + +class GetCustomerProfileSchema(BaseModel): + """Get customer profile input schema.""" + + customer_id: int = Field(..., description="The ID of the customer") + + +class CreateCustomerSchema(BaseModel): + """Create customer input schema.""" + + name: str = Field(..., description="Customer's full name") + email: str = Field(..., description="Customer's email address") + phone: Optional[str] = Field(None, description="Optional phone number") + + +class CreateTicketSchema(BaseModel): + """Create ticket input schema.""" + + customer_id: int = Field( + ..., description="The ID of the customer creating the ticket" + ) + subject: str = Field(..., description="Brief subject 
line for the ticket") + description: str = Field(..., description="Detailed description of the issue") + priority: str = Field( + "medium", + description="Priority level - 'low', 'medium', 'high', or 'urgent'", + ) + order_id: Optional[int] = Field( + None, description="Optional order ID if related to a specific order" + ) + + +class UpdateTicketSchema(BaseModel): + """Update ticket input schema.""" + + ticket_id: int = Field(..., description="The ID of the ticket to update") + status: Optional[str] = Field( + None, + description=( + "New status - 'open', 'in_progress', 'resolved', or 'closed'" + ), + ) + notes: Optional[str] = Field(None, description="Additional notes or updates") + + +class GetTicketSchema(BaseModel): + """Get ticket input schema.""" + + ticket_id: int = Field(..., description="The ID of the ticket to retrieve") + + +# Create LangChain Tools + + +def create_mcp_tools() -> List[StructuredTool]: + """Create LangChain tools that wrap MCP server tools.""" + + # Knowledge Base Tools + search_knowledge_base_tool = StructuredTool( + name="search_knowledge_base_tool", + description=( + "Search the knowledge base for information relevant to the " + "customer's query." + ), + args_schema=SearchKnowledgeBaseSchema, + func=lambda **kwargs: call_mcp_tool( + "search_knowledge_base_tool", **kwargs + ), + ) + + get_knowledge_base_by_category_tool = StructuredTool( + name="get_knowledge_base_by_category_tool", + description=( + "Get all knowledge base articles in a specific category." 
+ ), + args_schema=GetKnowledgeBaseByCategorySchema, + func=lambda **kwargs: call_mcp_tool( + "get_knowledge_base_by_category_tool", **kwargs + ), + ) + + # Order Tools + lookup_order_tool = StructuredTool( + name="lookup_order_tool", + description="Look up order details by order number.", + args_schema=LookupOrderSchema, + func=lambda **kwargs: call_mcp_tool("lookup_order_tool", **kwargs), + ) + + get_order_status_tool = StructuredTool( + name="get_order_status_tool", + description="Get the current status of an order.", + args_schema=GetOrderStatusSchema, + func=lambda **kwargs: call_mcp_tool("get_order_status_tool", **kwargs), + ) + + get_order_history_tool = StructuredTool( + name="get_order_history_tool", + description="Get all orders for a specific customer.", + args_schema=GetOrderHistorySchema, + func=lambda **kwargs: call_mcp_tool("get_order_history_tool", **kwargs), + ) + + # Customer Tools + lookup_customer_tool = StructuredTool( + name="lookup_customer_tool", + description="Look up a customer by email, phone number, or customer ID.", + args_schema=LookupCustomerSchema, + func=lambda **kwargs: call_mcp_tool("lookup_customer_tool", **kwargs), + ) + + get_customer_profile_tool = StructuredTool( + name="get_customer_profile_tool", + description=( + "Get full customer profile including order history and tickets." 
+ ), + args_schema=GetCustomerProfileSchema, + func=lambda **kwargs: call_mcp_tool( + "get_customer_profile_tool", **kwargs + ), + ) + + create_customer_tool = StructuredTool( + name="create_customer_tool", + description="Create a new customer record.", + args_schema=CreateCustomerSchema, + func=lambda **kwargs: call_mcp_tool("create_customer_tool", **kwargs), + ) + + # Ticket Tools + create_ticket_tool = StructuredTool( + name="create_ticket_tool", + description="Create a new support ticket for a customer.", + args_schema=CreateTicketSchema, + func=lambda **kwargs: call_mcp_tool("create_ticket_tool", **kwargs), + ) + + update_ticket_tool = StructuredTool( + name="update_ticket_tool", + description="Update an existing support ticket.", + args_schema=UpdateTicketSchema, + func=lambda **kwargs: call_mcp_tool("update_ticket_tool", **kwargs), + ) + + get_ticket_tool = StructuredTool( + name="get_ticket_tool", + description="Retrieve details of a specific support ticket.", + args_schema=GetTicketSchema, + func=lambda **kwargs: call_mcp_tool("get_ticket_tool", **kwargs), + ) + + tools = [ + search_knowledge_base_tool, + get_knowledge_base_by_category_tool, + lookup_order_tool, + get_order_status_tool, + get_order_history_tool, + lookup_customer_tool, + get_customer_profile_tool, + create_customer_tool, + create_ticket_tool, + update_ticket_tool, + get_ticket_tool, + ] + logger.info(f"Created {len(tools)} MCP tool wrappers") + return tools diff --git a/examples/customer_support_agent/src/agent/state.py b/examples/customer_support_agent/src/agent/state.py new file mode 100644 index 0000000..de1652f --- /dev/null +++ b/examples/customer_support_agent/src/agent/state.py @@ -0,0 +1,33 @@ +"""Agent state schema for LangGraph.""" + +from typing import TypedDict, List, Optional, Dict, Any, Annotated +from langchain_core.messages import BaseMessage +from langgraph.graph import add_messages + + +class AgentState(TypedDict): + """State schema for the customer support agent.""" + + 
# Conversation messages + messages: Annotated[List[BaseMessage], add_messages] + + # Customer identification + customer_id: Optional[int] + + # Current ticket being worked on + current_ticket_id: Optional[int] + + # History of tool calls + tool_calls: List[Dict[str, Any]] + + # Additional context + context: Dict[str, Any] + + # Intent classification + intent: Optional[str] + + # Confidence score for the current response + confidence: float + + # Whether to escalate to human + should_escalate: bool diff --git a/examples/customer_support_agent/src/agent/tools/__init__.py b/examples/customer_support_agent/src/agent/tools/__init__.py new file mode 100644 index 0000000..c38d254 --- /dev/null +++ b/examples/customer_support_agent/src/agent/tools/__init__.py @@ -0,0 +1 @@ +"""Tools for the customer support agent.""" diff --git a/examples/customer_support_agent/src/agent/tools/email.py b/examples/customer_support_agent/src/agent/tools/email.py new file mode 100644 index 0000000..a965bb2 --- /dev/null +++ b/examples/customer_support_agent/src/agent/tools/email.py @@ -0,0 +1,78 @@ +"""Email sending tool.""" + +from langchain_core.tools import tool +import os +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from typing import Optional + + +def send_email_smtp( + to_email: str, + subject: str, + body: str, + smtp_server: Optional[str] = None, + smtp_port: Optional[int] = None, + smtp_username: Optional[str] = None, + smtp_password: Optional[str] = None, + from_email: Optional[str] = None, +) -> bool: + """Send email using SMTP.""" + smtp_server = smtp_server or os.getenv("SMTP_SERVER", "smtp.gmail.com") + smtp_port = smtp_port or int(os.getenv("SMTP_PORT", "587")) + smtp_username = smtp_username or os.getenv("SMTP_USERNAME") + smtp_password = smtp_password or os.getenv("SMTP_PASSWORD") + from_email = from_email or os.getenv("EMAIL_FROM") or smtp_username + + if not smtp_username or not smtp_password: + raise ValueError( + 
"SMTP credentials not configured. " + "Set SMTP_USERNAME and SMTP_PASSWORD in .env" + ) + + try: + msg = MIMEMultipart() + msg["From"] = from_email + msg["To"] = to_email + msg["Subject"] = subject + msg.attach(MIMEText(body, "plain")) + + with smtplib.SMTP(smtp_server, smtp_port) as server: + server.starttls() + server.login(smtp_username, smtp_password) + server.send_message(msg) + + return True + except Exception as e: + raise Exception(f"Failed to send email: {str(e)}") + + +@tool +def send_email_tool(to: str, subject: str, body: str) -> str: + """ + Send an email notification to a customer. + + Args: + to: Recipient email address + subject: Email subject line + body: Email body content + + Returns: + Confirmation message if email was sent successfully, or an error message. + """ + try: + send_email_smtp(to, subject, body) + return f"Email sent successfully to {to}\n" f"Subject: {subject}" + except ValueError as e: + return ( + f"Email configuration error: {str(e)}\n" + "Please configure SMTP settings in your .env file:\n" + "- SMTP_SERVER\n" + "- SMTP_PORT\n" + "- SMTP_USERNAME\n" + "- SMTP_PASSWORD\n" + "- EMAIL_FROM" + ) + except Exception as e: + return f"Error sending email: {str(e)}" diff --git a/examples/customer_support_agent/src/agent/tools/web_search.py b/examples/customer_support_agent/src/agent/tools/web_search.py new file mode 100644 index 0000000..858a3b8 --- /dev/null +++ b/examples/customer_support_agent/src/agent/tools/web_search.py @@ -0,0 +1,90 @@ +"""Web search tool using DuckDuckGo.""" + +from langchain_core.tools import tool + +try: + # Try new package name first + from ddgs import DDGS + + DDGS_AVAILABLE = True +except ImportError: + try: + # Fallback to old package name + from duckduckgo_search import DDGS + + DDGS_AVAILABLE = True + except ImportError: + DDGS_AVAILABLE = False + + +@tool +def web_search_tool(query: str, max_results: int = 5) -> str: + """ + Search the web for current information that may not be in the knowledge base. 

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 5)

    Returns:
        A formatted string with search results, or an error message if search fails.
    """
    if not DDGS_AVAILABLE:
        return (
            "Error: DuckDuckGo search is not available. "
            "Install it with: pip install duckduckgo-search"
        )

    try:
        with DDGS() as ddgs:
            # Materialize the generator so failures surface inside this try.
            results = list(ddgs.text(query, max_results=max_results))

        if not results:
            return f"No results found for query: '{query}'"

        result_text = f"Web Search Results for '{query}':\n\n"
        for i, result in enumerate(results, 1):
            # Body text is truncated to 200 chars to keep the tool output
            # small enough for the LLM context.
            result_text += f"{i}. {result.get('title', 'No title')}\n"
            result_text += f"   URL: {result.get('href', 'No URL')}\n"
            result_text += f"   {result.get('body', 'No description')[:200]}...\n\n"

        return result_text
    except Exception as e:
        return f"Error performing web search: {str(e)}"


@tool
def web_search_news_tool(query: str, max_results: int = 5) -> str:
    """
    Search for recent news articles related to the query.

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 5)

    Returns:
        A formatted string with news search results.
    """
    if not DDGS_AVAILABLE:
        return (
            "Error: DuckDuckGo search is not available. "
            "Install it with: pip install duckduckgo-search"
        )

    try:
        with DDGS() as ddgs:
            results = list(ddgs.news(query, max_results=max_results))

        if not results:
            return f"No news results found for query: '{query}'"

        result_text = f"News Results for '{query}':\n\n"
        for i, result in enumerate(results, 1):
            # NOTE(review): news results use 'url'/'source' keys while text
            # results use 'href' - presumably matches the ddgs API; confirm.
            result_text += f"{i}. {result.get('title', 'No title')}\n"
            result_text += f"   Source: {result.get('source', 'Unknown')}\n"
            result_text += f"   URL: {result.get('url', 'No URL')}\n"
            result_text += f"   {result.get('body', 'No description')[:200]}...\n\n"

        return result_text
    except Exception as e:
        return f"Error performing news search: {str(e)}"
diff --git a/examples/customer_support_agent/src/api.py b/examples/customer_support_agent/src/api.py
new file mode 100644
index 0000000..747de44
--- /dev/null
+++ b/examples/customer_support_agent/src/api.py
@@ -0,0 +1,417 @@
"""FastAPI server for the customer support agent."""

import os
import sys
import warnings
from pathlib import Path
from typing import Optional, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage

# Suppress websockets deprecation warnings
warnings.filterwarnings(
    "ignore", category=DeprecationWarning, module="websockets"
)
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    module="uvicorn.protocols.websockets",
)

# Initialize logger early
from src.utils.logger import get_logger  # noqa: E402
logger = get_logger(__name__)

# Add project root to path for imports
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.agent.database.setup import init_database  # noqa: E402
from src.agent.graph import get_agent  # noqa: E402


# Load environment variables
load_dotenv()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan event handler for startup and shutdown.

    Database init failures are logged and deliberately swallowed so the
    API still starts (best-effort seeding).
    """
    # Startup
    logger.info("Starting FastAPI application")
    try:
        logger.debug("Initializing database...")
        init_database(seed_data=True)
        logger.info("Database initialized successfully")
    except
Exception as e:
        logger.error(f"Error initializing database: {e}", exc_info=True)
    yield
    # Shutdown
    logger.info("Shutting down FastAPI application")


# Initialize FastAPI app
app = FastAPI(
    title="Customer Support Agent API",
    description="AI-powered customer support agent with LangGraph",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
# NOTE(review): wildcard origins with allow_credentials=True is permissive;
# fine for a local example, tighten before any real deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request/Response models
class ChatRequest(BaseModel):
    """Request model for chat endpoint."""

    message: str
    thread_id: Optional[str] = "default"
    customer_id: Optional[int] = None


class ChatResponse(BaseModel):
    """Response model for chat endpoint."""

    response: str
    thread_id: str
    tool_calls: Optional[List[dict]] = None
    intent: Optional[str] = None
    confidence: Optional[float] = None


class GenerateRequest(BaseModel):
    """Request model for generate endpoint."""

    message: str
    thread_id: Optional[str] = None
    customer_id: Optional[int] = None


class GenerateResponse(BaseModel):
    """Response model for generate endpoint."""

    response: str
    thread_id: str


# Initialize agent (lazy loading)
# Module-level singleton: also what makes conversation memory persist,
# since get_agent() builds a fresh MemorySaver per call.
_agent = None


def get_agent_instance():
    """Get or create agent instance."""
    global _agent
    if _agent is None:
        try:
            logger.debug("Initializing agent instance")
            _agent = get_agent()
            logger.info("Agent instance created successfully")
        except Exception as e:
            logger.error(f"Failed to initialize agent: {e}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail=f"Failed to initialize agent: {str(e)}",
            )
    return _agent


def reset_agent():
    """Reset agent instance (for testing/debugging)."""
    global _agent
    _agent = None


@app.get("/")
async def root():
    """Root endpoint."""
    logger.debug("Root endpoint accessed")
    return {
        "message": "Customer Support Agent API",
        "version": "1.0.0",
        "endpoints": {
            "/chat": (
                "POST - Chat with the agent (maintains conversation state)"
            ),
            "/generate": "POST - Generate a single response (no state)",
            "/health": "GET - Health check",
        },
    }


@app.get("/health")
async def health():
    """Health check endpoint."""
    logger.debug("Health check endpoint accessed")
    return {"status": "healthy", "service": "customer-support-agent"}


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):  # noqa: C901
    """
    Chat endpoint that maintains conversation state.

    The agent remembers previous messages in the conversation thread,
    allowing for multi-turn conversations.
    """
    thread_id = request.thread_id or "default"
    msg_len = len(request.message)
    logger.info(
        f"Chat request received - thread_id: {thread_id}, "
        f"message length: {msg_len}"
    )
    logger.debug(f"Chat message: {request.message[:200]}")

    try:
        agent = get_agent_instance()

        # Create human message
        human_message = HumanMessage(content=request.message)

        # Prepare input state
        input_state = {"messages": [human_message]}

        # Configure thread ID for state management
        config = {"configurable": {"thread_id": thread_id}}

        # Stream the agent response; only the last event (final node
        # output) is kept for extraction below.
        logger.debug(f"Streaming agent response for thread: {thread_id}")
        final_state = None
        event_count = 0
        for event in agent.stream(input_state, config):
            final_state = event
            event_count += 1
            logger.debug(f"Agent event {event_count}: {list(event.keys())}")

        logger.debug(f"Agent completed with {event_count} events")

        # Extract response and metadata
        response_text = ""
        tool_calls_list = []
        intent = None
        confidence = None

        if final_state:
            # Look for response in synthesize node first: the newest
            # content-bearing AI message without pending tool calls.
            if "synthesize" in final_state:
                messages = final_state["synthesize"].get("messages", [])
                for msg in reversed(messages):
                    if hasattr(msg, "content") and msg.content:
                        if not (hasattr(msg, "tool_calls") and msg.tool_calls):
                            response_text = msg.content
                            break

                # Extract metadata from synthesize node
                if "tool_calls" in final_state["synthesize"]:
                    tool_calls_list = final_state["synthesize"].get("tool_calls", [])
                if "confidence" in final_state["synthesize"]:
                    confidence = final_state["synthesize"].get("confidence")

            # Extract intent from understand_intent node
            if "understand_intent" in final_state:
                intent = final_state["understand_intent"].get("intent")

            # Fallback: check all nodes for a usable plain AI message when
            # the synthesize node produced nothing.
            if not response_text:
                for node_name, node_output in final_state.items():
                    if node_name in ["start", "understand_intent"]:
                        continue  # Skip these nodes
                    messages = node_output.get("messages", [])
                    if messages:
                        for msg in reversed(messages):
                            if hasattr(msg, "content") and msg.content:
                                if not (hasattr(msg, "tool_calls") and
                                        msg.tool_calls):
                                    # Skip generic greetings so a boilerplate
                                    # opener is never returned as the answer.
                                    greetings = [
                                        "How can I assist you today?",
                                        "Hello! How can I help you?",
                                    ]
                                    if msg.content not in greetings:
                                        response_text = msg.content
                                        break
                        if response_text:
                            break

        if not response_text:
            logger.warning(
                f"No response text extracted for thread: {thread_id}"
            )
            response_text = (
                "I apologize, but I couldn't generate a response. "
                "Please try again."
            )
        else:
            tool_count = len(tool_calls_list) if tool_calls_list else 0
            logger.info(
                f"Response generated - length: {len(response_text)}, "
                f"tool_calls: {tool_count}"
            )

        return ChatResponse(
            response=response_text,
            thread_id=thread_id,
            tool_calls=tool_calls_list if tool_calls_list else None,
            intent=intent,
            confidence=confidence,
        )

    except Exception as e:
        logger.error(f"Error processing chat request: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error processing chat request: {str(e)}",
        )


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):  # noqa: C901
    """
    Generate endpoint for single responses without maintaining state.

    Each request is treated as a new conversation.
    """
    import uuid

    # A fresh UUID thread per request keeps the checkpointer from reusing
    # any prior conversation state.
    thread_id = request.thread_id or str(uuid.uuid4())
    msg_len = len(request.message)
    logger.info(
        f"Generate request received - thread_id: {thread_id}, "
        f"message length: {msg_len}"
    )
    logger.debug(f"Generate message: {request.message[:200]}")

    try:
        agent = get_agent_instance()

        # Create human message
        human_message = HumanMessage(content=request.message)

        # Prepare input state
        input_state = {"messages": [human_message]}

        # Use a unique thread ID for each request (no state persistence)
        config = {"configurable": {"thread_id": thread_id}}

        # Stream the agent response and collect all states
        logger.debug(f"Streaming agent response for thread: {thread_id}")
        all_states = []
        event_count = 0
        for event in agent.stream(input_state, config):
            all_states.append(event)
            event_count += 1
            logger.debug(f"Agent event {event_count}: {list(event.keys())}")

        logger.debug(
            f"Agent completed with {event_count} events, "
            f"{len(all_states)} states collected"
        )

        # Extract response - check all states in reverse order
        response_text = ""
        from langchain_core.messages import AIMessage, ToolMessage

        # Check if tools were
called by looking for ToolMessage in any state
        tools_called = False
        for state in all_states:
            if "tools" in state:
                tool_messages = state["tools"].get("messages", [])
                if any(isinstance(msg, ToolMessage) for msg in tool_messages):
                    tools_called = True
                    break

        # Priority 1: If tools were called, use synthesize node from final
        # state
        if tools_called and all_states:
            final_state = all_states[-1]
            if "synthesize" in final_state:
                messages = final_state["synthesize"].get("messages", [])
                for msg in reversed(messages):
                    if (isinstance(msg, AIMessage) and
                            hasattr(msg, "content") and msg.content):
                        if not (hasattr(msg, "tool_calls") and
                                msg.tool_calls):
                            response_text = msg.content
                            break

        # Priority 2: If no tools called, use call_tools node response
        # (check all states)
        if not response_text:
            for state_idx, state in enumerate(reversed(all_states)):
                if "call_tools" in state:
                    messages = state["call_tools"].get("messages", [])
                    for msg in reversed(messages):
                        if (isinstance(msg, AIMessage) and
                                hasattr(msg, "content") and msg.content):
                            if not (hasattr(msg, "tool_calls") and
                                    msg.tool_calls):
                                content = msg.content.strip()
                                # Always use call_tools response if it exists
                                # and is not empty
                                if content:
                                    response_text = content
                                    break
                    if response_text:
                        break

        # Priority 3: Fallback to synthesize node
        if not response_text and all_states:
            final_state = all_states[-1]
            if "synthesize" in final_state:
                messages = final_state["synthesize"].get("messages", [])
                for msg in reversed(messages):
                    if (isinstance(msg, AIMessage) and
                            hasattr(msg, "content") and msg.content):
                        if not (hasattr(msg, "tool_calls") and
                                msg.tool_calls):
                            content = msg.content.strip()
                            # Same greeting filter as /chat: never return a
                            # boilerplate opener as the answer.
                            greetings = [
                                "How can I assist you today?",
                                "Hello! How can I help you?",
                            ]
                            if content and content not in greetings:
                                response_text = content
                                break

        if not response_text:
            logger.warning(
                f"No response text extracted for thread: {thread_id}"
            )
            response_text = (
                "I apologize, but I couldn't generate a response. "
                "Please try again."
            )
        else:
            logger.info(f"Response generated - length: {len(response_text)}")
            logger.debug(f"Response preview: {response_text[:200]}")

        return GenerateResponse(response=response_text, thread_id=thread_id)

    except Exception as e:
        logger.error(f"Error processing generate request: {e}", exc_info=True)
        import traceback
        # NOTE(review): the full traceback is returned in the 500 body -
        # useful for a demo, an information leak in production; confirm
        # this is intentional.
        error_detail = (
            f"Error processing generate request: {str(e)}\n"
            f"{traceback.format_exc()}"
        )
        raise HTTPException(status_code=500, detail=error_detail)


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 8000))
    logger.info(f"Starting FastAPI server on 0.0.0.0:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)
diff --git a/examples/customer_support_agent/src/main.py b/examples/customer_support_agent/src/main.py
new file mode 100644
index 0000000..d67315f
--- /dev/null
+++ b/examples/customer_support_agent/src/main.py
@@ -0,0 +1,417 @@
"""Main entry point for the customer support agent."""

import os
import sys
import time
from pathlib import Path
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import after path setup
from src.agent.database.setup import init_database  # noqa: E402
from src.agent.graph import get_agent  # noqa: E402


def load_environment():
    """Load environment variables from .env file.

    Exits the process with a helpful message if either required key is
    missing.
    """
    env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
    if os.path.exists(env_path):
        load_dotenv(env_path)
    else:
        # Try current directory
        load_dotenv()

    # Check for required environment variables
    highflame_key
= os.getenv("HIGHFLAME_API_KEY") + llm_api_key = os.getenv("LLM_API_KEY") + + if not highflame_key: + print("Error: HIGHFLAME_API_KEY not found in environment variables.") + print("Please create a .env file with your Highflame API key:") + print("HIGHFLAME_API_KEY=your_key_here") + sys.exit(1) + + if not llm_api_key: + print("Error: LLM_API_KEY not found in environment variables.") + print("Please create a .env file with your LLM API key:") + print("LLM_API_KEY=your_key_here") + sys.exit(1) + + +def initialize_database(): + """Initialize the database with tables and sample data.""" + print("Initializing database...") + try: + init_database(seed_data=True) + print("✓ Database initialized successfully.\n") + except Exception as e: + print(f"Error initializing database: {e}") + sys.exit(1) + + +def print_separator(): + """Print a separator line.""" + print("\n" + "=" * 80 + "\n") + + +def print_message(role: str, content: str, tool_calls=None): + """Print a formatted message.""" + print(f"[{role.upper()}]") + print(content) + if tool_calls: + tools_used = ', '.join(tc.get('tool', 'unknown') for tc in tool_calls) + print(f"\n🔧 Tools Used: {tools_used}") + print() + + +def run_test_conversations(): # noqa: C901 + """Run automated test conversations to verify functionality.""" + print_separator() + print("🧪 AUTOMATED TEST SUITE - Multi-Turn Conversations") + print_separator() + + # Load environment and initialize + load_environment() + initialize_database() + + # Get the agent + try: + agent = get_agent() + print("✓ Agent initialized successfully\n") + except Exception as e: + print(f"✗ Error creating agent: {e}") + sys.exit(1) + + # Test scenarios + test_scenarios = [ + { + "name": "Order Lookup Test", + "thread_id": "test_order_lookup", + "messages": [ + "What is the status of order ORD-001?", + "Can you tell me the total amount for that order?", + "What was the last question I asked you?", + ], + "expected_tools": ["get_order_status_tool"], + }, + { + "name": 
"Customer Lookup Test", + "thread_id": "test_customer_lookup", + "messages": [ + "Who is customer john.doe@example.com?", + "Show me their order history", + "What did I ask you first?", + ], + "expected_tools": ["lookup_customer_tool", "get_order_history_tool"], + }, + { + "name": "Knowledge Base Test", + "thread_id": "test_knowledge_base", + "messages": [ + "What is your return policy?", + "How long does shipping take?", + "Can you remind me what I asked about shipping?", + ], + "expected_tools": ["search_knowledge_base_tool"], + }, + { + "name": "Ticket Creation Test", + "thread_id": "test_ticket_creation", + "messages": [ + "I need help with my order", + "My order number is ORD-002", + "Create a support ticket for this issue", + ], + "expected_tools": ["create_ticket_tool"], + }, + { + "name": "Memory Test", + "thread_id": "test_memory", + "messages": [ + "My name is Test User and my email is test@example.com", + "What is my name?", + "What is my email?", + "What were the first two things I told you?", + ], + "expected_tools": [], + }, + { + "name": "Multi-Tool Test", + "thread_id": "test_multi_tool", + "messages": [ + "I want to check order ORD-001 and also know the return policy", + "Can you create a ticket for order ORD-001?", + ], + "expected_tools": [ + "lookup_order_tool", + "search_knowledge_base_tool", + "create_ticket_tool", + ], + }, + ] + + # Run each test scenario + for scenario_idx, scenario in enumerate(test_scenarios, 1): + print_separator() + print(f"📋 Test {scenario_idx}: {scenario['name']}") + print(f"Thread ID: {scenario['thread_id']}") + print_separator() + + thread_id = scenario["thread_id"] + config = {"configurable": {"thread_id": thread_id}} + + all_tool_calls = [] + + for turn_idx, user_message in enumerate(scenario["messages"], 1): + print(f"\n--- Turn {turn_idx} ---") + print_message("user", user_message) + + try: + # Create human message + human_message = HumanMessage(content=user_message) + input_state = {"messages": [human_message]} 
+ + # Get agent response + print("🤖 Agent processing...") + final_state = None + response_text = "" + tool_calls_list = [] + + collected_tool_calls = [] + + for event in agent.stream(input_state, config): + final_state = event + + for node_name, node_output in event.items(): + if not node_output: + continue + messages = node_output.get("messages", []) + + # Collect tool calls directly from node output + node_tool_calls = node_output.get("tool_calls") or [] + if node_tool_calls: + collected_tool_calls.extend(node_tool_calls) + all_tool_calls.extend(node_tool_calls) + + # Collect tool calls embedded in messages + for msg in messages: + if getattr(msg, "tool_calls", None): + for tc in msg.tool_calls: + tool_call_entry = { + "tool": tc.get("name", "unknown"), + "args": tc.get("args", {}), + } + collected_tool_calls.append(tool_call_entry) + all_tool_calls.append(tool_call_entry) + + # Extract response and tool calls + if final_state: + for node_name, node_output in final_state.items(): + if not node_output: + continue + messages = node_output.get("messages", []) + if messages: + # Find the last AI message (non-tool-call message) + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + has_tool_calls = ( + hasattr(msg, "tool_calls") and + msg.tool_calls + ) + if not has_tool_calls: + response_text = msg.content + break + + # Use collected tool calls for this turn + if collected_tool_calls: + tool_calls_list = collected_tool_calls + + if not response_text: + response_text = "No response generated." 
+ + print_message( + "assistant", + response_text, + tool_calls_list if tool_calls_list else None, + ) + + # Small delay between turns + time.sleep(0.5) + + except Exception as e: + print(f"✗ Error in turn {turn_idx}: {e}") + import traceback + + traceback.print_exc() + break + + # Verify expected tools were used + if scenario.get("expected_tools"): + used_tools = [tc.get("tool") for tc in all_tool_calls] + expected = scenario["expected_tools"] + found = [tool for tool in expected if tool in used_tools] + print("\n✓ Tools Verification:") + print(f" Expected: {expected}") + print(f" Found: {found}") + if len(found) == len(expected): + print(" ✅ All expected tools were used!") + else: + print(" ⚠️ Some expected tools were not used") + + print(f"\n✓ Test '{scenario['name']}' completed") + time.sleep(1) + + print_separator() + print("✅ All test scenarios completed!") + print_separator() + + +def run_conversation(): # noqa: C901 + """Run interactive conversation loop with the agent.""" + print("\n" + "=" * 60) + print("Customer Support Agent - Interactive Mode") + print("=" * 60) + print("Type 'quit' or 'exit' to end the conversation.\n") + + # Load environment and initialize + load_environment() + initialize_database() + + # Get the agent + try: + agent = get_agent() + except Exception as e: + print(f"Error creating agent: {e}") + sys.exit(1) + + # Conversation thread ID for state management + thread_id = "customer_support_session" + config = {"configurable": {"thread_id": thread_id}} + + print("Agent ready! How can I help you today?\n") + + while True: + try: + # Get user input + user_input = input("You: ").strip() + + if not user_input: + continue + + if user_input.lower() in ["quit", "exit", "q"]: + print( + "\nThank you for contacting customer support. " + "Have a great day!" 
+ ) + break + + # Create human message + human_message = HumanMessage(content=user_input) + + # Invoke agent with the new message + # LangGraph will maintain state via checkpointer + input_state = {"messages": [human_message]} + + # Stream the response + print("\nAgent: ", end="", flush=True) + + final_state = None + tool_calls_list = [] + + for event in agent.stream(input_state, config): + final_state = event + + # Extract and print the final response + response_text = "" + if final_state: + for node_name, node_output in final_state.items(): + messages = node_output.get("messages", []) + if messages: + # Find the last AI message (non-tool-call message) + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + if not (hasattr(msg, "tool_calls") and msg.tool_calls): + response_text = msg.content + break + + # Extract tool calls + if "tool_calls" in node_output: + tool_calls_list = node_output.get("tool_calls", []) + + if response_text: + print(response_text) + if tool_calls_list: + tools_used = ', '.join( + tc.get('tool', 'unknown') for tc in tool_calls_list + ) + print(f"\n🔧 Tools used: {tools_used}") + else: + print("I apologize, but I couldn't generate a response.") + + print() # New line after response + + except KeyboardInterrupt: + print("\n\nConversation interrupted. 
Goodbye!") + break + except Exception as e: + print(f"\nError: {e}") + import traceback + + traceback.print_exc() + print("Please try again or type 'quit' to exit.\n") + + +def run_single_query(query: str): + """Run a single query and return the response.""" + load_environment() + initialize_database() + + try: + agent = get_agent() + except Exception as e: + print(f"Error creating agent: {e}") + sys.exit(1) + + thread_id = "single_query_session" + config = {"configurable": {"thread_id": thread_id}} + + human_message = HumanMessage(content=query) + input_state = {"messages": [human_message]} + + # Get final response + final_state = None + for event in agent.stream(input_state, config): + final_state = event + + # Extract response from final state + if final_state: + for node_name, node_output in final_state.items(): + messages = node_output.get("messages", []) + if messages: + # Find the last AI message (non-tool-call message) + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + if not (hasattr(msg, "tool_calls") and msg.tool_calls): + return msg.content + + return "No response generated." 
+ + +if __name__ == "__main__": + # Check command line arguments + if len(sys.argv) > 1: + if sys.argv[1] == "test": + # Run automated tests + run_test_conversations() + else: + # Single query mode + query = " ".join(sys.argv[1:]) + response = run_single_query(query) + print(response) + else: + # Run interactive mode + run_conversation() diff --git a/examples/customer_support_agent/src/mcp_server/__init__.py b/examples/customer_support_agent/src/mcp_server/__init__.py new file mode 100644 index 0000000..154ecb6 --- /dev/null +++ b/examples/customer_support_agent/src/mcp_server/__init__.py @@ -0,0 +1 @@ +"""MCP Server for database operations.""" diff --git a/examples/customer_support_agent/src/mcp_server/server.py b/examples/customer_support_agent/src/mcp_server/server.py new file mode 100644 index 0000000..61cd08a --- /dev/null +++ b/examples/customer_support_agent/src/mcp_server/server.py @@ -0,0 +1,548 @@ +"""FastMCP server with database tools.""" + +import os +import sys +from pathlib import Path +from typing import Optional + +# Add parent directory to path to import agent modules +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +from fastmcp import FastMCP # noqa: E402 +from dotenv import load_dotenv # noqa: E402 + +# Load environment variables +load_dotenv() + +import warnings # noqa: E402 +# Suppress websockets deprecation warnings +warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="websockets" +) +warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="uvicorn.protocols.websockets", +) + +# Import database modules +from src.agent.database.setup import get_session, init_database # noqa: E402 +from src.agent.database.queries import ( # noqa: E402 + # Knowledge base + search_knowledge_base, + get_knowledge_base_by_category, + # Orders + get_order_by_number, + get_orders_by_customer, + # Customers + get_customer_by_email, + get_customer_by_id, + get_customer_by_phone, + 
create_customer as db_create_customer, + # Tickets + create_ticket as db_create_ticket, + update_ticket as db_update_ticket, + get_ticket_by_id, +) +from src.agent.database.models import TicketStatus, TicketPriority # noqa: E402 +from src.utils.logger import get_logger # noqa: E402 + +logger = get_logger(__name__) + +# Initialize FastMCP server +mcp = FastMCP( + name="customer-support-db-server", + instructions="Database operations server for customer support agent", +) + + +# ============ Knowledge Base Tools ============ + + +@mcp.tool() +def search_knowledge_base_tool(query: str, category: Optional[str] = None) -> str: + """ + Search the knowledge base for information relevant to the customer's query. + + Args: + query: The search query or keywords to search for + category: Optional category to filter results (e.g., 'policies', 'shipping', 'account') # noqa: E501 + + Returns: + A formatted string with relevant knowledge base articles, or a message if no results found. # noqa: E501 + """ + logger.info(f"Searching knowledge base - query: {query}, category: {category}") + db = get_session() + try: + articles = search_knowledge_base(db, query, category, limit=5) + + if not articles: + logger.warning(f"No knowledge base articles found for query: {query}") + return f"No knowledge base articles found for query: '{query}'" + + logger.info(f"Found {len(articles)} knowledge base articles") + result = f"Found {len(articles)} relevant article(s):\n\n" + for i, article in enumerate(articles, 1): + result += f"{i}. 
**{article.title}**\n" + result += f" Category: {article.category or 'Uncategorized'}\n" + result += f" Content: {article.content[:200]}...\n" + if article.tags: + result += f" Tags: {article.tags}\n" + result += "\n" + + return result + except Exception as e: + logger.error(f"Error searching knowledge base: {e}", exc_info=True) + return f"Error searching knowledge base: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def get_knowledge_base_by_category_tool(category: str) -> str: + """ + Get all knowledge base articles in a specific category. + + Args: + category: The category to retrieve articles from (e.g., 'policies', 'shipping', 'account', 'orders', 'billing', 'support') # noqa: E501 + + Returns: + A formatted string with all articles in the category. + """ + db = get_session() + try: + articles = get_knowledge_base_by_category(db, category) + + if not articles: + return f"No articles found in category: '{category}'" + + result = f"Found {len(articles)} article(s) in category '{category}':\n\n" + for i, article in enumerate(articles, 1): + result += f"{i}. **{article.title}**\n" + result += f" {article.content[:300]}...\n\n" + + return result + except Exception as e: + return f"Error retrieving knowledge base articles: {str(e)}" + finally: + db.close() + + +# ============ Order Tools ============ + + +@mcp.tool() +def lookup_order_tool(order_number: str, customer_id: Optional[int] = None) -> str: + """ + Look up order details by order number. + + Args: + order_number: The order number to look up (e.g., 'ORD-001') + customer_id: Optional customer ID to verify the order belongs to the customer + + Returns: + A formatted string with order details, or an error message if not found. + """ + db = get_session() + try: + order = get_order_by_number(db, order_number) + if not order: + return f"Error: Order '{order_number}' not found." 
+ + # Verify customer if provided + if customer_id and order.customer_id != customer_id: + return f"Error: Order '{order_number}' does not belong to customer {customer_id}." # noqa: E501 + + customer = get_customer_by_id(db, order.customer_id) + customer_name = customer.name if customer else "Unknown" + + return ( + f"Order Details:\n" + f"Order Number: {order.order_number}\n" + f"Status: {order.status.value}\n" + f"Total: ${order.total:.2f}\n" + f"Customer: {customer_name} (ID: {order.customer_id})\n" + f"Order Date: {order.created_at}" + ) + except Exception as e: + return f"Error looking up order: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def get_order_status_tool(order_number: str) -> str: + """ + Get the current status of an order. + + Args: + order_number: The order number to check + + Returns: + The current status of the order. + """ + db = get_session() + try: + order = get_order_by_number(db, order_number) + if not order: + return f"Error: Order '{order_number}' not found." + + return ( + f"Order {order_number} Status: {order.status.value}\n" + f"Last Updated: {order.created_at}" + ) + except Exception as e: + return f"Error getting order status: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def get_order_history_tool(customer_id: int) -> str: + """ + Get all orders for a specific customer. + + Args: + customer_id: The ID of the customer + + Returns: + A formatted string with all orders for the customer. + """ + db = get_session() + try: + customer = get_customer_by_id(db, customer_id) + if not customer: + return f"Error: Customer with ID {customer_id} not found." + + orders = get_orders_by_customer(db, customer_id) + + if not orders: + return f"No orders found for customer {customer.name} (ID: {customer_id})." + + result = f"Order History for {customer.name}:\n\n" + for i, order in enumerate(orders, 1): + result += ( + f"{i}. 
Order {order.order_number}\n" + f" Status: {order.status.value}\n" + f" Total: ${order.total:.2f}\n" + f" Date: {order.created_at}\n\n" + ) + + return result + except Exception as e: + return f"Error getting order history: {str(e)}" + finally: + db.close() + + +# ============ Customer Tools ============ + + +@mcp.tool() +def lookup_customer_tool( + email: Optional[str] = None, + phone: Optional[str] = None, + customer_id: Optional[int] = None, +) -> str: + """ + Look up a customer by email, phone number, or customer ID. + + Args: + email: Customer email address + phone: Customer phone number + customer_id: Customer ID + + Returns: + Customer information if found, or an error message. + """ + db = get_session() + try: + customer = None + + if customer_id: + customer = get_customer_by_id(db, customer_id) + elif email: + customer = get_customer_by_email(db, email) + elif phone: + customer = get_customer_by_phone(db, phone) + else: + return "Error: Must provide at least one of: email, phone, or customer_id" + + if not customer: + identifier = customer_id or email or phone + return f"Error: Customer not found with identifier: {identifier}" + + return ( + f"Customer Information:\n" + f"ID: {customer.id}\n" + f"Name: {customer.name}\n" + f"Email: {customer.email}\n" + f"Phone: {customer.phone or 'Not provided'}\n" + f"Member Since: {customer.created_at}" + ) + except Exception as e: + return f"Error looking up customer: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def get_customer_profile_tool(customer_id: int) -> str: + """ + Get full customer profile including order history and tickets. + + Args: + customer_id: The ID of the customer + + Returns: + A comprehensive customer profile with orders and tickets. + """ + db = get_session() + try: + customer = get_customer_by_id(db, customer_id) + if not customer: + return f"Error: Customer with ID {customer_id} not found." 
+ + orders = get_orders_by_customer(db, customer_id) + tickets = customer.tickets + + result = ( + f"Customer Profile:\n" + f"ID: {customer.id}\n" + f"Name: {customer.name}\n" + f"Email: {customer.email}\n" + f"Phone: {customer.phone or 'Not provided'}\n" + f"Member Since: {customer.created_at}\n\n" + ) + + # Add order summary + result += f"Orders: {len(orders)} total\n" + if orders: + result += "Recent Orders:\n" + for order in orders[:5]: # Show last 5 orders + result += f" - {order.order_number}: {order.status.value} (${order.total:.2f})\n" # noqa: E501 + + result += "\n" + + # Add ticket summary + result += f"Support Tickets: {len(tickets)} total\n" + if tickets: + result += "Recent Tickets:\n" + for ticket in tickets[:5]: # Show last 5 tickets + result += f" - Ticket #{ticket.id}: {ticket.subject} ({ticket.status.value})\n" # noqa: E501 + + return result + except Exception as e: + return f"Error retrieving customer profile: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def create_customer_tool(name: str, email: str, phone: Optional[str] = None) -> str: + """ + Create a new customer record. + + Args: + name: Customer's full name + email: Customer's email address + phone: Optional phone number + + Returns: + A confirmation message with the new customer ID and details. 
+ """ + db = get_session() + try: + # Check if customer already exists + existing = get_customer_by_email(db, email) + if existing: + return f"Error: Customer with email '{email}' already exists (ID: {existing.id})" # noqa: E501 + + customer = db_create_customer(db, name, email, phone) + + return ( + f"Customer created successfully!\n" + f"ID: {customer.id}\n" + f"Name: {customer.name}\n" + f"Email: {customer.email}\n" + f"Phone: {customer.phone or 'Not provided'}" + ) + except Exception as e: + return f"Error creating customer: {str(e)}" + finally: + db.close() + + +# ============ Ticket Tools ============ + + +@mcp.tool() +def create_ticket_tool( + customer_id: int, + subject: str, + description: str, + priority: str = "medium", + order_id: Optional[int] = None, +) -> str: + """ + Create a new support ticket for a customer. + + Args: + customer_id: The ID of the customer creating the ticket + subject: Brief subject line for the ticket + description: Detailed description of the issue + priority: Priority level - 'low', 'medium', 'high', or 'urgent' (default: 'medium') # noqa: E501 + order_id: Optional order ID if the ticket is related to a specific order + + Returns: + A confirmation message with the ticket ID and details. + """ + db = get_session() + try: + # Validate customer exists + customer = get_customer_by_id(db, customer_id) + if not customer: + return f"Error: Customer with ID {customer_id} not found." 
+ + # Map priority string to enum + priority_map = { + "low": TicketPriority.LOW, + "medium": TicketPriority.MEDIUM, + "high": TicketPriority.HIGH, + "urgent": TicketPriority.URGENT, + } + priority_enum = priority_map.get(priority.lower(), TicketPriority.MEDIUM) + + ticket = ( + db_create_ticket(db, customer_id, subject, description, priority_enum, order_id)) # noqa: E501 + + return ( + f"Support ticket created successfully!\n" + f"Ticket ID: {ticket.id}\n" + f"Subject: {ticket.subject}\n" + f"Priority: {ticket.priority.value}\n" + f"Status: {ticket.status.value}\n" + f"Created: {ticket.created_at}" + ) + except Exception as e: + return f"Error creating ticket: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def update_ticket_tool( + ticket_id: int, status: Optional[str] = None, notes: Optional[str] = None +) -> str: + """ + Update an existing support ticket. + + Args: + ticket_id: The ID of the ticket to update + status: New status - 'open', 'in_progress', 'resolved', or 'closed' + notes: Additional notes or updates to add to the ticket + + Returns: + A confirmation message with updated ticket details. + """ + db = get_session() + try: + ticket = get_ticket_by_id(db, ticket_id) + if not ticket: + return f"Error: Ticket with ID {ticket_id} not found." + + # Map status string to enum + status_enum = None + if status: + status_map = { + "open": TicketStatus.OPEN, + "in_progress": TicketStatus.IN_PROGRESS, + "resolved": TicketStatus.RESOLVED, + "closed": TicketStatus.CLOSED, + } + status_enum = status_map.get(status.lower()) + if not status_enum: + return f"Error: Invalid status '{status}'. 
Valid values: open, in_progress, resolved, closed" # noqa: E501 + + updated_ticket = db_update_ticket(db, ticket_id, status_enum, notes) + + if not updated_ticket: + return f"Error: Failed to update ticket {ticket_id}" + + return ( + f"Ticket updated successfully!\n" + f"Ticket ID: {updated_ticket.id}\n" + f"Subject: {updated_ticket.subject}\n" + f"Status: {updated_ticket.status.value}\n" + f"Priority: {updated_ticket.priority.value}\n" + f"Last Updated: {updated_ticket.updated_at}" + ) + except Exception as e: + return f"Error updating ticket: {str(e)}" + finally: + db.close() + + +@mcp.tool() +def get_ticket_tool(ticket_id: int) -> str: + """ + Retrieve details of a specific support ticket. + + Args: + ticket_id: The ID of the ticket to retrieve + + Returns: + A formatted string with ticket details. + """ + db = get_session() + try: + ticket = get_ticket_by_id(db, ticket_id) + if not ticket: + return f"Error: Ticket with ID {ticket_id} not found." + + customer = get_customer_by_id(db, ticket.customer_id) + customer_name = customer.name if customer else "Unknown" + + result = ( + f"Ticket Details:\n" + f"ID: {ticket.id}\n" + f"Subject: {ticket.subject}\n" + f"Description: {ticket.description}\n" + f"Status: {ticket.status.value}\n" + f"Priority: {ticket.priority.value}\n" + f"Customer: {customer_name} (ID: {ticket.customer_id})\n" + ) + + if ticket.order_id: + result += f"Related Order ID: {ticket.order_id}\n" + + result += f"Created: {ticket.created_at}\n" f"Last Updated: {ticket.updated_at}" + + return result + except Exception as e: + return f"Error retrieving ticket: {str(e)}" + finally: + db.close() + + +if __name__ == "__main__": + # Initialize database before starting server + logger.info("Initializing database...") + try: + init_database(seed_data=True) + logger.info("Database initialized successfully") + except Exception as e: + logger.error(f"Database initialization failed: {e}", exc_info=True) + logger.warning("Server will continue, but database 
operations may fail") + + host = os.getenv("MCP_SERVER_HOST", "0.0.0.0") + port = int(os.getenv("MCP_SERVER_PORT", "9000")) + logger.info(f"Starting MCP server on {host}:{port}") + logger.info("Using streamable-http transport for HTTP access") + logger.info(f"Server will be available at: http://{host}:{port}/mcp/") + try: + mcp.run(transport="streamable-http", host=host, port=port) + except Exception as e: + logger.error(f"Error starting MCP server: {e}", exc_info=True) + raise diff --git a/examples/customer_support_agent/src/ui/app.py b/examples/customer_support_agent/src/ui/app.py new file mode 100644 index 0000000..e033241 --- /dev/null +++ b/examples/customer_support_agent/src/ui/app.py @@ -0,0 +1,443 @@ +"""Streamlit UI for Customer Support Agent.""" + +import streamlit as st +import sys +import os +from pathlib import Path + +# Add project root to path (go up from src/ui/app.py to project root) +current_file = Path(__file__).resolve() +project_root = current_file.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# Load environment variables from .env file +from dotenv import load_dotenv # noqa: E402 + +env_path = project_root / ".env" +if env_path.exists(): + load_dotenv(env_path) +else: + load_dotenv() # Try default locations + +from langchain_core.messages import HumanMessage # noqa: E402 +from src.agent.graph import get_agent # noqa: E402 +from src.utils.logger import get_logger # noqa: E402 + +# Initialize logger +logger = get_logger(__name__) + +# Page config +st.set_page_config(page_title="Customer Support Agent", page_icon="💬", layout="wide") + +# Initialize session state +if "conversations" not in st.session_state: + st.session_state.conversations = {} +if "current_thread" not in st.session_state: + st.session_state.current_thread = None +if "agent" not in st.session_state: + try: + # Verify Highflame API keys are loaded + logger.info("Checking Highflame API keys") + + highflame_key = 
os.getenv("HIGHFLAME_API_KEY") + llm_api_key = os.getenv("LLM_API_KEY") + route = os.getenv("HIGHFLAME_ROUTE") + model = os.getenv("MODEL") + + if not highflame_key: + st.error("HIGHFLAME_API_KEY not found. Please check your .env file.") + st.info(f"Looking for .env at: {env_path}") + st.stop() + + if not llm_api_key: + st.error("LLM_API_KEY not found. Please check your .env file.") + st.info(f"Looking for .env at: {env_path}") + st.stop() + + if not route: + st.error("HIGHFLAME_ROUTE not found. Please check your .env file.") + st.info(f"Looking for .env at: {env_path}") + st.stop() + + if not model: + st.error("MODEL not found. Please check your .env file.") + st.info(f"Looking for .env at: {env_path}") + st.stop() + + logger.info(f"Highflame configured with route: {route}, model: {model}") + + st.session_state.agent = get_agent() + logger.info("Agent initialized successfully in Streamlit") + except ValueError as e: + st.error(f"Configuration error: {e}") + st.info("Please ensure your .env file contains:") + st.info(" - HIGHFLAME_API_KEY") + st.info(" - HIGHFLAME_ROUTE (openai or google)") + st.info(" - MODEL (e.g., gpt-4o-mini or gemini-2.5-flash-lite)") + st.info(" - LLM_API_KEY (OpenAI key for openai route, Gemini key for google route)") # noqa: E501 + logger.error(f"Configuration error during Streamlit agent initialization: {e}", exc_info=True) # noqa: E501 + st.stop() + except Exception as e: + st.error(f"Failed to initialize agent: {e}") + logger.critical(f"Critical error initializing agent in Streamlit: {e}", exc_info=True) # noqa: E501 + st.stop() + +# Sidebar for conversations +with st.sidebar: + st.title("💬 Conversations") + + # New conversation button + if st.button("➕ New Conversation", use_container_width=True): + import uuid + + new_thread = str(uuid.uuid4()) + st.session_state.conversations[new_thread] = [] + st.session_state.current_thread = new_thread + st.rerun() + + st.divider() + + # List of conversations + st.subheader("Recent") + thread_names 
= list(st.session_state.conversations.keys()) + + if not thread_names: + st.caption("No conversations yet. Start a new one!") + else: + for thread_id in thread_names: + # Get conversation title (first message or thread ID) + title = f"Thread {thread_id[:8]}" + if st.session_state.conversations[thread_id]: + first_msg = st.session_state.conversations[thread_id][0] + if isinstance(first_msg, dict) and first_msg.get("role") == "user": + title = first_msg.get("content", title)[:30] + "..." + + # Select conversation + if st.button( + title, + key=f"thread_{thread_id}", + use_container_width=True, + type="primary" if thread_id == st.session_state.current_thread else "secondary", # noqa: E501 + ): + st.session_state.current_thread = thread_id + st.rerun() + + st.divider() + st.caption("💡 Tip: Each conversation maintains its own context and memory.") + +# Main chat area +st.title("Customer Support Agent") + +# Initialize current conversation if needed +if st.session_state.current_thread is None: + import uuid + + st.session_state.current_thread = str(uuid.uuid4()) + st.session_state.conversations[st.session_state.current_thread] = [] + +current_thread = st.session_state.current_thread +if current_thread not in st.session_state.conversations: + st.session_state.conversations[current_thread] = [] + +# Display conversation history +chat_container = st.container() +with chat_container: + for message in st.session_state.conversations[current_thread]: + role = message.get("role", "user") + content = message.get("content", "") + + if role == "user": + with st.chat_message("user"): + st.write(content) + else: + with st.chat_message("assistant"): + st.write(content) + + # Show tool calls if available + if message and isinstance(message, dict) and "tool_calls" in message and message["tool_calls"]: # noqa: E501 + with st.expander("🔧 Tools Used"): + for tool_call in message["tool_calls"]: + st.code(f"{tool_call.get('tool', 'unknown')}") + +# Chat input +user_input = 
st.chat_input("Type your message here...") + +if user_input: # noqa: C901 + # Add user message to conversation + st.session_state.conversations[current_thread].append({"role": "user", "content": user_input}) # noqa: E501 + + # Display user message immediately + with st.chat_message("user"): + st.write(user_input) + + # Get agent response + with st.chat_message("assistant"): + with st.spinner("Thinking..."): + try: + logger.info(f"Processing user message in thread: {current_thread}") + logger.debug(f"User message: {user_input[:200]}") + + agent = st.session_state.agent + human_message = HumanMessage(content=user_input) + + input_state = {"messages": [human_message]} + config = {"configurable": {"thread_id": current_thread}} + + # Stream response and track all events for debugging + final_state = None + response_text = "" + tool_calls_list = [] + debug_events = [] # Track all events for debug view + event_count = 0 + + logger.debug("Starting agent stream") + for event in agent.stream(input_state, config): + final_state = event + event_count += 1 + logger.debug(f"Agent event {event_count}: {list(event.keys())}") + + # Collect debug info from each event + for node_name, node_output in event.items(): + try: + # Fix: Check None first to avoid "argument of type 'NoneType' is not iterable" # noqa: E501 + if node_output is None: + logger.debug(f"Node {node_name} has None output") + debug_events.append( + { + "node": node_name, + "has_messages": False, + "has_tool_calls": False, + "tool_calls": [], + } + ) + continue + + # Now safe to check for keys + has_messages = "messages" in node_output + has_tool_calls = "tool_calls" in node_output + tool_calls = node_output.get("tool_calls", []) + + debug_events.append( + { + "node": node_name, + "has_messages": has_messages, + "has_tool_calls": has_tool_calls, + "tool_calls": tool_calls, + } + ) + except Exception as e: + logger.error(f"Error processing debug event for node {node_name}: {e}", exc_info=True) # noqa: E501 + 
debug_events.append( + { + "node": node_name, + "has_messages": False, + "has_tool_calls": False, + "tool_calls": [], + "error": str(e), + } + ) + + logger.debug(f"Agent stream completed with {event_count} events") + + # Extract response and tool calls + logger.debug("Extracting response from final state") + if final_state: + try: + # Priority 1: Check synthesize node + if "synthesize" in final_state: + logger.debug("Checking synthesize node for response") + synthesize_output = final_state.get("synthesize") + if synthesize_output: + messages = ( + synthesize_output.get("messages", []) if synthesize_output else []) # noqa: E501 + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + tool_calls_attr = ( + getattr(msg, "tool_calls", None)) + if not tool_calls_attr: + response_text = msg.content + logger.debug(f"Found response in synthesize node (length: {len(response_text)})") # noqa: E501 + break + + # Priority 2: Check call_tools node + if not response_text and "call_tools" in final_state: + logger.debug("Checking call_tools node for response") + call_tools_output = final_state.get("call_tools") + if call_tools_output: + messages = ( + call_tools_output.get("messages", []) if call_tools_output else []) # noqa: E501 + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + tool_calls_attr = ( + getattr(msg, "tool_calls", None)) + if not tool_calls_attr: + response_text = msg.content + logger.debug(f"Found response in call_tools node (length: {len(response_text)})") # noqa: E501 + break + + # Priority 3: Check all nodes + if not response_text: + logger.debug("Checking all nodes for response") + for node_name, node_output in final_state.items(): + if node_name in ["start", "understand_intent"]: + continue # Skip these nodes + + if node_output is None: + logger.debug(f"Skipping {node_name} - node_output is None") # noqa: E501 + continue + + try: + messages = ( + node_output.get("messages", []) if node_output else []) # 
noqa: E501 + if messages: + # Find the last AI message (non-tool-call message) # noqa: E501 + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content: + tool_calls_attr = ( + getattr(msg, "tool_calls", None)) + if not tool_calls_attr: + response_text = msg.content + logger.debug(f"Found response in {node_name} node (length: {len(response_text)})") # noqa: E501 + break + if response_text: + break + except Exception as e: + logger.error(f"Error extracting response from node {node_name}: {e}", exc_info=True) # noqa: E501 + continue + + # Extract tool calls from state + logger.debug("Extracting tool calls from final state") + for node_name, node_output in final_state.items(): + if node_output is None: + continue + + try: + if "tool_calls" in node_output: + tool_calls = node_output.get("tool_calls", []) + if isinstance(tool_calls, list): + tool_calls_list.extend(tool_calls) + logger.debug(f"Found {len(tool_calls)} tool calls in {node_name}") # noqa: E501 + + # Also check messages for tool calls + messages = ( + node_output.get("messages", []) if node_output else []) # noqa: E501 + if messages: + for msg in messages: + if hasattr(msg, "tool_calls"): + tool_calls_attr = ( + getattr(msg, "tool_calls", None)) + if tool_calls_attr and isinstance(tool_calls_attr, list): # noqa: E501 + for tc in tool_calls_attr: + if isinstance(tc, dict): + tool_calls_list.append( + { + "tool": tc.get("name", "unknown"), # noqa: E501 + "args": tc.get("args", {}), # noqa: E501 + "id": tc.get("id", ""), + } + ) + except Exception as e: + logger.error(f"Error extracting tool calls from node {node_name}: {e}", exc_info=True) # noqa: E501 + continue + except Exception as e: + logger.error(f"Error extracting response from final_state: {e}", exc_info=True) # noqa: E501 + response_text = f"Error processing response: {str(e)}" + + if not response_text: + logger.warning("No response text extracted from agent") + response_text = ( + "I apologize, but I couldn't generate a response. 
Please try again." # noqa: E501 + ) + else: + logger.info(f"Response extracted successfully (length: {len(response_text)})") # noqa: E501 + logger.debug(f"Response preview: {response_text[:200]}") + + # Display response + st.write(response_text) + + # Show tool calls with detailed info + if tool_calls_list: + with st.expander(f"🔧 Tools Used ({len(tool_calls_list)})"): + for i, tool_call in enumerate(tool_calls_list, 1): + tool_name = tool_call.get("tool", "unknown") + tool_args = tool_call.get("args", {}) + st.markdown(f"**{i}. {tool_name}**") + if tool_args: + st.json(tool_args) + else: + st.caption("ℹ️ No tools were used for this response") + + # Debug view showing conversation flow + with st.expander("🐛 Debug View (Conversation Flow)"): + st.markdown("### Event Flow") + for i, event in enumerate(debug_events, 1): + st.markdown(f"**{i}. {event['node']}**") + st.text(f" Messages: {event.get('message_count', 0)}") + if event.get("tool_calls"): + st.text(f" Tool calls: {len(event['tool_calls'])}") + for tc in event["tool_calls"]: + st.code(f" - {tc.get('tool', 'unknown')}") + + st.markdown("### Memory State") + st.text(f"Thread ID: {current_thread}") + st.text( + f"Messages in UI state: {len(st.session_state.conversations[current_thread])}" # noqa: E501 + ) + st.text("✓ Memory maintained via LangGraph checkpointer (thread_id)") # noqa: E501 + st.caption( + "Note: LangGraph's MemorySaver maintains conversation history internally using the thread_id" # noqa: E501 + ) + + # Debug view + with st.expander("🐛 Debug View (Conversation Flow)"): + st.markdown("### Event Flow") + for i, event in enumerate(debug_events, 1): + st.markdown(f"**{i}. 
{event['node']}**") + if event["has_tool_calls"] and event["tool_calls"]: + st.code(f"Tool calls: {event['tool_calls']}") + + st.markdown("### Full State") + if final_state: + for node_name, node_output in final_state.items(): + st.markdown(f"**{node_name}:**") + try: + if node_output is not None: + if "messages" in node_output: + msg_count = len(node_output["messages"]) + st.text(f" Messages: {msg_count}") + if "tool_calls" in node_output: + st.text(f" Tool calls: {len(node_output['tool_calls'])}") # noqa: E501 + else: + st.text(" (No output)") + except Exception as e: + logger.error(f"Error displaying state for node {node_name}: {e}", exc_info=True) # noqa: E501 + st.text(f" Error: {str(e)}") + + st.markdown("### Memory Check") + st.text(f"Thread ID: {current_thread}") + msg_count = len( + st.session_state.conversations[current_thread] + ) + st.text(f"Conversation messages in UI: {msg_count}") + + # Add assistant response to conversation + st.session_state.conversations[current_thread].append( + { + "role": "assistant", + "content": response_text, + "tool_calls": ( + tool_calls_list if tool_calls_list else None + ), + } + ) + + except Exception as e: + error_msg = f"Error: {str(e)}" + logger.error(f"Error in Streamlit UI: {e}", exc_info=True) + st.error(error_msg) + st.session_state.conversations[current_thread].append( + {"role": "assistant", "content": error_msg} + ) + + st.rerun() diff --git a/examples/customer_support_agent/src/ui/db_viewer.py b/examples/customer_support_agent/src/ui/db_viewer.py new file mode 100644 index 0000000..58db27c --- /dev/null +++ b/examples/customer_support_agent/src/ui/db_viewer.py @@ -0,0 +1,173 @@ +"""Database viewer and query tool for SQLite database.""" + +import streamlit as st +import sqlite3 +import pandas as pd +import sys +import os +from pathlib import Path + +# Add project root to path (go up from src/ui/db_viewer.py to project root) +current_file = Path(__file__).resolve() +project_root = 
current_file.parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +from src.agent.database.setup import get_database_path # noqa: E402 + +st.set_page_config(page_title="Database Viewer", page_icon="🗄️", layout="wide") + +st.title("🗄️ Database Viewer") + +# Get database path and resolve to absolute path +db_path_str = get_database_path() +# Resolve relative paths to absolute +if not os.path.isabs(db_path_str): + # Get project root (go up from src/ui/db_viewer.py) + project_root = Path(__file__).resolve().parent.parent.parent + db_path = (project_root / db_path_str.lstrip("./")).resolve() +else: + db_path = Path(db_path_str) + +st.sidebar.info(f"Database: `{db_path}`") + +# Connect to database +try: + # Ensure database file exists + if not db_path.exists(): + st.warning(f"Database file not found at: {db_path}") + st.info("The database will be created when you first use the agent.") + st.stop() + + # Convert Path to string for sqlite3 + conn = sqlite3.connect(str(db_path)) +except Exception as e: + st.error(f"Failed to connect to database: {e}") + st.error(f"Path attempted: {db_path}") + st.stop() + +# Get table list +cursor = conn.cursor() +cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") +tables = [row[0] for row in cursor.fetchall()] + +# Sidebar for table selection +st.sidebar.header("Tables") +selected_table = st.sidebar.selectbox("Select a table", tables) + +if selected_table: + # Display table data + st.subheader(f"Table: `{selected_table}`") + + # Get table info + cursor.execute(f"PRAGMA table_info({selected_table})") + columns = cursor.fetchall() + + col1, col2 = st.columns(2) + with col1: + st.metric("Columns", len(columns)) + + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {selected_table}") + row_count = cursor.fetchone()[0] + with col2: + st.metric("Rows", row_count) + + # Display columns info + with st.expander("📋 Column Information"): + if columns: + df_columns = 
pd.DataFrame( + columns, columns=["CID", "Name", "Type", "Not Null", "Default", "PK"] + ) + st.dataframe(df_columns, use_container_width=True) + + # Display data + st.subheader("Data") + query = f"SELECT * FROM {selected_table} LIMIT 100" + df = pd.read_sql_query(query, conn) + st.dataframe(df, use_container_width=True, height=400) + + # Download button + csv = df.to_csv(index=False) + st.download_button( + label="📥 Download as CSV", + data=csv, + file_name=f"{selected_table}.csv", + mime="text/csv", + ) + +# Custom SQL query +st.divider() +st.subheader("🔍 Custom SQL Query") + +query_text = st.text_area( + "Enter SQL query", height=100, placeholder="SELECT * FROM customers LIMIT 10;" +) + +if st.button("Execute Query"): + if query_text.strip(): + try: + result_df = pd.read_sql_query(query_text, conn) + st.dataframe(result_df, use_container_width=True) + st.success(f"Query executed successfully. Returned {len(result_df)} rows.") + except Exception as e: + st.error(f"Query error: {e}") + else: + st.warning("Please enter a SQL query.") + +# Predefined queries +st.divider() +st.subheader("📊 Predefined Queries") + +predefined_queries = { + "All Customers": "SELECT * FROM customers ORDER BY created_at DESC LIMIT 20;", + "All Orders": ( + "SELECT o.*, c.name as customer_name, c.email FROM orders o " + "JOIN customers c ON o.customer_id = c.id " + "ORDER BY o.created_at DESC LIMIT 20;" + ), + "All Tickets": ( + "SELECT t.*, c.name as customer_name, c.email FROM tickets t " + "JOIN customers c ON t.customer_id = c.id " + "ORDER BY t.created_at DESC LIMIT 20;" + ), + "Knowledge Base Articles": "SELECT * FROM knowledge_base ORDER BY created_at DESC;", + "Orders by Status": ( + "SELECT status, COUNT(*) as count FROM orders GROUP BY status;" + ), + "Tickets by Priority": ( + "SELECT priority, COUNT(*) as count FROM tickets GROUP BY priority;" + ), + "Tickets by Status": ( + "SELECT status, COUNT(*) as count FROM tickets GROUP BY status;" + ), + "Customer Order Summary": """ + 
SELECT + c.id, + c.name, + c.email, + COUNT(o.id) as total_orders, + SUM(o.total) as total_spent, + COUNT(t.id) as total_tickets + FROM customers c + LEFT JOIN orders o ON c.id = o.customer_id + LEFT JOIN tickets t ON c.id = t.customer_id + GROUP BY c.id, c.name, c.email + ORDER BY total_spent DESC + LIMIT 20; + """, +} + +for query_name, query in predefined_queries.items(): + if st.button(f"Run: {query_name}", key=f"predef_{query_name}"): + try: + result_df = pd.read_sql_query(query, conn) + st.dataframe(result_df, use_container_width=True) + st.success( + f"Query '{query_name}' executed. " + f"Returned {len(result_df)} rows." + ) + except Exception as e: + st.error(f"Query error: {e}") + +conn.close() diff --git a/examples/customer_support_agent/src/utils/__init__.py b/examples/customer_support_agent/src/utils/__init__.py new file mode 100644 index 0000000..183c974 --- /dev/null +++ b/examples/customer_support_agent/src/utils/__init__.py @@ -0,0 +1 @@ +"""Utility modules.""" diff --git a/examples/customer_support_agent/src/utils/logger.py b/examples/customer_support_agent/src/utils/logger.py new file mode 100644 index 0000000..6aa89b1 --- /dev/null +++ b/examples/customer_support_agent/src/utils/logger.py @@ -0,0 +1,86 @@ +"""Logging configuration for the application.""" + +import os +import logging +from pathlib import Path +from datetime import datetime +from logging.handlers import RotatingFileHandler + + +def setup_logger(name: str, log_level: str = "INFO") -> logging.Logger: + """ + Set up a logger with file and console handlers. 
+ + Args: + name: Logger name (typically __name__) + log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + + Returns: + Configured logger instance + """ + logger = logging.getLogger(name) + logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + + # Avoid duplicate handlers + if logger.handlers: + return logger + + # Create logs directory + project_root = Path(__file__).resolve().parent.parent.parent + logs_dir = project_root / "logs" + logs_dir.mkdir(exist_ok=True) + + # Create timestamped log file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = logs_dir / f"app_{timestamp}.log" + + # File handler with rotation (10MB max, keep 5 backups) + file_handler = RotatingFileHandler( + log_file, + maxBytes=10 * 1024 * 1024, # 10MB + backupCount=5, + encoding="utf-8", + ) + file_handler.setLevel(logging.DEBUG) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + # Formatter + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - " + "%(filename)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + simple_formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S", + ) + + file_handler.setFormatter(detailed_formatter) + console_handler.setFormatter(simple_formatter) + + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + +def get_logger(name: str) -> logging.Logger: + """ + Get or create a logger instance. 
+ + Args: + name: Logger name (typically __name__) + + Returns: + Logger instance + """ + logger = logging.getLogger(name) + if not logger.handlers: + # Use INFO level by default, can be overridden with LOG_LEVEL env var + log_level = os.getenv("LOG_LEVEL", "INFO") + return setup_logger(name, log_level) + return logger diff --git a/examples/customer_support_agent/start.sh b/examples/customer_support_agent/start.sh new file mode 100755 index 0000000..ad4cb5c --- /dev/null +++ b/examples/customer_support_agent/start.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# Customer Support Agent Startup Script +# This script starts all required services: MCP Server, FastAPI Server, and Streamlit UI + +set -e # Exit on error + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Get the directory where the script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}Customer Support Agent Startup${NC}" +echo -e "${BLUE}========================================${NC}" +echo "" + +# Check if .env file exists +if [ ! -f ".env" ]; then + echo -e "${RED}Error: .env file not found!${NC}" + echo -e "${YELLOW}Please create a .env file with the following variables:${NC}" + echo "" + echo "HIGHFLAME_API_KEY=your_highflame_api_key" + echo "HIGHFLAME_ROUTE=google # or 'openai'" + echo "MODEL=gemini-2.5-flash-lite # or 'gpt-4o-mini' for OpenAI" + echo "LLM_API_KEY=your_openai_or_gemini_api_key" + echo "MCP_SERVER_URL=http://localhost:9000/mcp" + echo "MCP_SERVER_HOST=0.0.0.0" + echo "MCP_SERVER_PORT=9000" + echo "DATABASE_PATH=./src/db/support_agent.db" + echo "PORT=8000" + echo "" + exit 1 +fi + +# Check if Python is available +if ! 
command -v python3 &> /dev/null; then + echo -e "${RED}Error: python3 not found!${NC}" + echo "Please install Python 3.11 or higher" + exit 1 +fi + +# Check if required packages are installed +echo -e "${YELLOW}Checking dependencies...${NC}" +if ! python3 -c "import fastapi, streamlit, langchain, fastmcp" 2>/dev/null; then + echo -e "${YELLOW}Installing dependencies...${NC}" + pip install -r requirements.txt +fi +echo -e "${GREEN}✓ Dependencies checked${NC}" +echo "" + +# Create logs directory if it doesn't exist +mkdir -p logs + +# Function to check if a port is in use +check_port() { + local port=$1 + if lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1 ; then + return 0 # Port is in use + else + return 1 # Port is free + fi +} + +# Function to kill process on a port +kill_port() { + local port=$1 + local pid=$(lsof -ti :$port) + if [ ! -z "$pid" ]; then + echo -e "${YELLOW}Killing process on port $port (PID: $pid)${NC}" + kill -9 $pid 2>/dev/null || true + sleep 1 + fi +} + +# Check and free ports +echo -e "${YELLOW}Checking ports...${NC}" +if check_port 9000; then + echo -e "${YELLOW}Port 9000 is in use. Attempting to free it...${NC}" + kill_port 9000 +fi + +if check_port 8000; then + echo -e "${YELLOW}Port 8000 is in use. Attempting to free it...${NC}" + kill_port 8000 +fi + +if check_port 8501; then + echo -e "${YELLOW}Port 8501 is in use. Attempting to free it...${NC}" + kill_port 8501 +fi + +echo -e "${GREEN}✓ Ports checked${NC}" +echo "" + +# Function to cleanup on exit +cleanup() { + echo "" + echo -e "${YELLOW}Shutting down services...${NC}" + # Kill background processes + jobs -p | xargs -r kill 2>/dev/null || true + exit 0 +} + +# Set trap to cleanup on script exit +trap cleanup SIGINT SIGTERM EXIT + +# Start MCP Server +echo -e "${BLUE}Starting MCP Server on port 9000...${NC}" +python3 -m src.mcp_server.server > logs/mcp_server.log 2>&1 & +MCP_PID=$! +sleep 3 + +# Check if MCP server started successfully +if ! 
kill -0 $MCP_PID 2>/dev/null; then + echo -e "${RED}Error: MCP Server failed to start${NC}" + echo "Check logs/mcp_server.log for details" + exit 1 +fi +echo -e "${GREEN}✓ MCP Server started (PID: $MCP_PID)${NC}" +echo "" + +# Start FastAPI Server +echo -e "${BLUE}Starting FastAPI Server on port 8000...${NC}" +python3 -m uvicorn src.api:app --host 0.0.0.0 --port 8000 > logs/api_server.log 2>&1 & +API_PID=$! +sleep 3 + +# Check if API server started successfully +if ! kill -0 $API_PID 2>/dev/null; then + echo -e "${RED}Error: FastAPI Server failed to start${NC}" + echo "Check logs/api_server.log for details" + kill $MCP_PID 2>/dev/null || true + exit 1 +fi +echo -e "${GREEN}✓ FastAPI Server started (PID: $API_PID)${NC}" +echo "" + +# Start Streamlit UI +echo -e "${BLUE}Starting Streamlit UI on port 8501...${NC}" +cd src/ui +streamlit run app.py --server.port 8501 --server.address 0.0.0.0 > ../../logs/streamlit_ui.log 2>&1 & +UI_PID=$! +cd ../.. +sleep 5 + +# Check if Streamlit started successfully +if ! 
kill -0 $UI_PID 2>/dev/null; then + echo -e "${RED}Error: Streamlit UI failed to start${NC}" + echo "Check logs/streamlit_ui.log for details" + kill $MCP_PID $API_PID 2>/dev/null || true + exit 1 +fi +echo -e "${GREEN}✓ Streamlit UI started (PID: $UI_PID)${NC}" +echo "" + +# Display service URLs +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}All services started successfully!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo -e "${BLUE}Service URLs:${NC}" +echo -e " ${GREEN}Streamlit UI:${NC} http://localhost:8501" +echo -e " ${GREEN}FastAPI Server:${NC} http://localhost:8000" +echo -e " ${GREEN}API Docs:${NC} http://localhost:8000/docs" +echo -e " ${GREEN}MCP Server:${NC} http://localhost:9000/mcp" +echo "" +echo -e "${BLUE}Log Files:${NC}" +echo -e " MCP Server: logs/mcp_server.log" +echo -e " API Server: logs/api_server.log" +echo -e " Streamlit UI: logs/streamlit_ui.log" +echo "" +echo -e "${YELLOW}Press Ctrl+C to stop all services${NC}" +echo "" + +# Wait for all background processes +wait + diff --git a/examples/customer_support_agent/tests/__init__.py b/examples/customer_support_agent/tests/__init__.py new file mode 100644 index 0000000..2c20de3 --- /dev/null +++ b/examples/customer_support_agent/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the customer support agent.""" diff --git a/examples/customer_support_agent/tests/test_agent.py b/examples/customer_support_agent/tests/test_agent.py new file mode 100644 index 0000000..29d46d2 --- /dev/null +++ b/examples/customer_support_agent/tests/test_agent.py @@ -0,0 +1,258 @@ +"""Tests for the customer support agent.""" + +import pytest +import os +from dotenv import load_dotenv +from src.agent.database.setup import init_database, get_session +from src.agent.database.queries import get_customer_by_email, get_order_by_number +from src.agent.graph import get_agent +from src.agent.llm import get_llm +from langchain_core.messages import 
HumanMessage + +# Load environment variables +load_dotenv() + + +@pytest.fixture(scope="module") +def setup_database(): + """Initialize database for testing.""" + # Use a test database + os.environ["DATABASE_PATH"] = "./test_support_agent.db" + init_database(seed_data=True) + yield + # Cleanup + if os.path.exists("./test_support_agent.db"): + os.remove("./test_support_agent.db") + + +def test_database_initialization(setup_database): + """Test that database is initialized with mock data.""" + db = get_session() + try: + # Check customers + customer = get_customer_by_email(db, "john.doe@example.com") + assert customer is not None + assert customer.name == "John Doe" + + # Check orders + order = get_order_by_number(db, "ORD-001") + assert order is not None + assert order.total == 99.99 + + finally: + db.close() + + +def test_customer_lookup(setup_database): + """Test customer lookup functionality.""" + db = get_session() + try: + customer = get_customer_by_email(db, "jane.smith@example.com") + assert customer is not None + assert customer.email == "jane.smith@example.com" + finally: + db.close() + + +def test_order_lookup(setup_database): + """Test order lookup functionality.""" + db = get_session() + try: + order = get_order_by_number(db, "ORD-001") + assert order is not None + assert order.order_number == "ORD-001" + finally: + db.close() + + +def test_openai_llm_initialization(): + """Test OpenAI LLM initialization.""" + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + try: + llm = get_llm("openai") + assert llm is not None + assert llm.model_name == os.getenv("OPENAI_MODEL", "gpt-4o-mini") + except Exception as e: + pytest.fail(f"OpenAI LLM initialization failed: {e}") + + +def test_gemini_llm_initialization(): + """Test Gemini LLM initialization.""" + if not os.getenv("GEMINI_API_KEY"): + pytest.skip("GEMINI_API_KEY not set") + + try: + llm = get_llm("gemini") + assert llm is not None + except Exception as e: + pytest.fail(f"Gemini 
LLM initialization failed: {e}") + + +def test_llm_provider_env_var(): + """Test LLM provider selection via environment variable.""" + original_provider = os.getenv("LLM_PROVIDER") + + # Test OpenAI + if os.getenv("OPENAI_API_KEY"): + os.environ["LLM_PROVIDER"] = "openai" + llm = get_llm() + assert llm is not None + + # Test Gemini + if os.getenv("GEMINI_API_KEY"): + os.environ["LLM_PROVIDER"] = "gemini" + llm = get_llm() + assert llm is not None + + # Restore original + if original_provider: + os.environ["LLM_PROVIDER"] = original_provider + + +def test_agent_initialization(): + """Test that agent can be initialized.""" + # Skip if no API key + if not os.getenv("OPENAI_API_KEY") and not os.getenv("GEMINI_API_KEY"): + pytest.skip("No LLM API key set") + + try: + agent = get_agent() + assert agent is not None + except Exception as e: + pytest.skip(f"Agent initialization failed: {e}") + + +def test_agent_memory(): + """Test that agent maintains conversation memory.""" + # Skip if no API key + if not os.getenv("OPENAI_API_KEY") and not os.getenv("GEMINI_API_KEY"): + pytest.skip("No LLM API key set") + + try: + agent = get_agent() + thread_id = "test-memory-thread" + config = {"configurable": {"thread_id": thread_id}} + + # First message + input_state1 = {"messages": [HumanMessage(content="My name is Test User")]} + + # Process first message + for event in agent.stream(input_state1, config): + pass + + # Second message - should remember context + input_state2 = {"messages": [HumanMessage(content="What is my name?")]} + + # Process second message + final_state = None + for event in agent.stream(input_state2, config): + final_state = event + + # Verify state was maintained + assert final_state is not None + + except Exception as e: + pytest.skip(f"Memory test failed: {e}") + + +def test_mcp_tools_import(): + """Test that MCP tools can be imported and created.""" + try: + from src.agent.mcp_tools import create_mcp_tools + + tools = create_mcp_tools() + assert len(tools) 
== 11 + + # Verify tool names + tool_names = [tool.name for tool in tools] + assert "search_knowledge_base_tool" in tool_names + assert "lookup_order_tool" in tool_names + assert "lookup_customer_tool" in tool_names + assert "create_ticket_tool" in tool_names + except Exception as e: + pytest.fail(f"MCP tools import failed: {e}") + + +def test_direct_tools_import(): + """Test that direct tools can be imported.""" + try: + from src.agent.tools.web_search import web_search_tool, web_search_news_tool + from src.agent.tools.email import send_email_tool + + assert web_search_tool is not None + assert web_search_news_tool is not None + assert send_email_tool is not None + except Exception as e: + pytest.fail(f"Direct tools import failed: {e}") + + +def test_agent_with_openai(setup_database): + """Test agent with OpenAI provider.""" + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + # Set provider + original_provider = os.getenv("LLM_PROVIDER") + os.environ["LLM_PROVIDER"] = "openai" + + try: + agent = get_agent() + config = {"configurable": {"thread_id": "test-openai"}} + input_state = {"messages": [HumanMessage(content="Hello")]} + + final_state = None + for event in agent.stream(input_state, config): + final_state = event + + assert final_state is not None + except Exception as e: + pytest.fail(f"OpenAI agent test failed: {e}") + finally: + if original_provider: + os.environ["LLM_PROVIDER"] = original_provider + + +def test_agent_with_gemini(setup_database): + """Test agent with Gemini provider.""" + if not os.getenv("GEMINI_API_KEY"): + pytest.skip("GEMINI_API_KEY not set") + + # Set provider + original_provider = os.getenv("LLM_PROVIDER") + os.environ["LLM_PROVIDER"] = "gemini" + + try: + agent = get_agent() + config = {"configurable": {"thread_id": "test-gemini"}} + input_state = {"messages": [HumanMessage(content="Hello")]} + + final_state = None + for event in agent.stream(input_state, config): + final_state = event + + assert 
final_state is not None + except Exception as e: + pytest.fail(f"Gemini agent test failed: {e}") + finally: + if original_provider: + os.environ["LLM_PROVIDER"] = original_provider + + +def test_knowledge_base_search(setup_database): + """Test knowledge base search via MCP tools.""" + # Skip test if MCP server is not running + pytest.skip("MCP server tests require MCP server to be running") + + +def test_order_tool(setup_database): + """Test order lookup tool via MCP.""" + # Skip test if MCP server is not running + pytest.skip("MCP server tests require MCP server to be running") + + +def test_customer_tool(setup_database): + """Test customer lookup tool via MCP.""" + # Skip test if MCP server is not running + pytest.skip("MCP server tests require MCP server to be running") diff --git a/examples/customer_support_agent/tests/test_mcp_server.py b/examples/customer_support_agent/tests/test_mcp_server.py new file mode 100644 index 0000000..5098c65 --- /dev/null +++ b/examples/customer_support_agent/tests/test_mcp_server.py @@ -0,0 +1,133 @@ +"""Tests for MCP server functionality.""" + +import pytest +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + + +@pytest.fixture(scope="module") +def setup_database(): + """Initialize database for testing.""" + from src.agent.database.setup import init_database + + # Use a test database + os.environ["DATABASE_PATH"] = "./test_support_agent.db" + init_database(seed_data=True) + yield + # Cleanup + if os.path.exists("./test_support_agent.db"): + os.remove("./test_support_agent.db") + + +@pytest.mark.asyncio +async def test_mcp_client_connection(): + """Test MCP client can connect to server.""" + # Skip if MCP_SERVER_URL not set + if not os.getenv("MCP_SERVER_URL"): + pytest.skip("MCP_SERVER_URL not set") + + try: + from src.agent.mcp_tools import get_mcp_client + + client = get_mcp_client() + + async with client: + # Test connection by listing tools + tools = await client.list_tools() + assert 
tools is not None + assert len(tools) > 0 + except Exception as e: + pytest.skip(f"MCP server not available: {e}") + + +@pytest.mark.asyncio +async def test_mcp_server_tools_list(): + """Test MCP server exposes all expected tools.""" + if not os.getenv("MCP_SERVER_URL"): + pytest.skip("MCP_SERVER_URL not set") + + try: + from src.agent.mcp_tools import get_mcp_client + + client = get_mcp_client() + + async with client: + tools = await client.list_tools() + tool_names = [tool.name for tool in tools] + + # Verify all 11 tools are present + assert "search_knowledge_base_tool" in tool_names + assert "get_knowledge_base_by_category_tool" in tool_names + assert "lookup_order_tool" in tool_names + assert "get_order_status_tool" in tool_names + assert "get_order_history_tool" in tool_names + assert "lookup_customer_tool" in tool_names + assert "get_customer_profile_tool" in tool_names + assert "create_customer_tool" in tool_names + assert "create_ticket_tool" in tool_names + assert "update_ticket_tool" in tool_names + assert "get_ticket_tool" in tool_names + + assert len(tool_names) == 11 + except Exception as e: + pytest.skip(f"MCP server not available: {e}") + + +@pytest.mark.asyncio +async def test_mcp_tool_invocation(): + """Test calling an MCP tool.""" + if not os.getenv("MCP_SERVER_URL"): + pytest.skip("MCP_SERVER_URL not set") + + try: + from src.agent.mcp_tools import call_mcp_tool + + # Test order lookup + result = call_mcp_tool("get_order_status_tool", order_number="ORD-001") + assert result is not None + assert "ORD-001" in result + except Exception as e: + pytest.skip(f"MCP tool invocation failed: {e}") + + +def test_mcp_tool_creation(): + """Test MCP tools can be created.""" + from src.agent.mcp_tools import create_mcp_tools + + tools = create_mcp_tools() + assert len(tools) == 11 + + # Check that expected tools are present + tool_names = [tool.name for tool in tools] + assert "search_knowledge_base_tool" in tool_names + assert "lookup_order_tool" in 
tool_names + assert "lookup_customer_tool" in tool_names + assert "create_ticket_tool" in tool_names + + +def test_agent_initialization_with_mcp(): + """Test that agent can be initialized with MCP tools.""" + # Skip if no API key + if not os.getenv("OPENAI_API_KEY") and not os.getenv("GEMINI_API_KEY"): + pytest.skip("No LLM API key set") + + try: + from src.agent.graph import get_agent + + agent = get_agent() + assert agent is not None + except Exception as e: + pytest.skip(f"Agent initialization failed: {e}") + + +def test_mcp_server_module_import(): + """Test MCP server module can be imported.""" + try: + from src.mcp_server import server + + assert hasattr(server, "mcp") + except Exception as e: + pytest.fail(f"MCP server module import failed: {e}")