diff --git a/02-use-cases/text-to-sql-data-analyst/.env.example b/02-use-cases/text-to-sql-data-analyst/.env.example new file mode 100644 index 000000000..e2b4b2dbe --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/.env.example @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ============================================================================= +# Text-to-SQL Data Analyst Assistant - Environment Variables +# ============================================================================= +# Copy this file as .env and fill in your values: +# cp .env.example .env + +# --- AWS --- +AWS_REGION=us-east-1 +AWS_ACCOUNT_ID=123456789012 + +# --- Glue Data Catalog --- +GLUE_DATABASE_NAME=my_company_demo + +# --- Athena --- +ATHENA_OUTPUT_LOCATION=s3://my-company-text-to-sql-athena/results/ + +# --- S3 Data Lake --- +DEMO_S3_BUCKET=my-company-text-to-sql-data + +# --- AgentCore Memory (optional, configured via agentcore CLI) --- +# AGENTCORE_MEMORY_ID= + +# --- Project name (used for naming AWS resources) --- +PROJECT_NAME=my-company diff --git a/02-use-cases/text-to-sql-data-analyst/.gitignore b/02-use-cases/text-to-sql-data-analyst/.gitignore new file mode 100644 index 000000000..4a98116db --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/.gitignore @@ -0,0 +1,18 @@ +.venv/ +__pycache__/ +*.pyc +*.egg-info/ +.env +cdk.out/ +cdk.context.json +.bedrock_agentcore/ +.bedrock_agentcore.yaml +data/ +.vscode/ +.idea/ +.DS_Store +*.log +lambda_package/ +.pytest_cache/ +htmlcov/ +.coverage diff --git a/02-use-cases/text-to-sql-data-analyst/README.md b/02-use-cases/text-to-sql-data-analyst/README.md new file mode 100644 index 000000000..db3a4af37 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/README.md @@ -0,0 +1,267 @@ +# Text-to-SQL Data Analyst Assistant + +A natural language to SQL data analyst assistant built with Amazon Bedrock AgentCore, Strands Agents SDK, and 
Amazon Athena. Users ask questions in plain language, and the agent discovers schema from AWS Glue Data Catalog, generates optimized SQL, executes it on Athena, and returns formatted results. + +> **πŸš€ Ready-to-Deploy Agent Web Application**: Use this reference solution to build natural language data query interfaces across different industries. Extend the agent capabilities by adding custom tools, connecting to different data sources, and adapting the business dictionary to your domain. + +## 🎯 Overview + +Text-to-SQL Data Analyst Assistant enables users to: + +- Ask questions about their data in natural language +- Get AI-generated SQL queries executed automatically on Amazon Athena +- View results in a clean web interface with schema visualization +- Benefit from conversational memory (STM + LTM) that learns query patterns across sessions +- Configure tables and business context via YAML β€” no code changes needed + +### Key Features + +- πŸ€– **AI-Powered SQL Generation** using Claude Sonnet 4 via Strands Agents SDK +- πŸ—„οΈ **Automatic Schema Discovery** from AWS Glue Data Catalog +- πŸ”’ **4-Layer Security**: Amazon Bedrock Guardrails β†’ System Prompt β†’ PolicyValidator β†’ AWS Lake Formation +- 🧠 **Dual Memory**: STM (session context) + LTM (learned SQL patterns, TTL 90 days) +- βš™οΈ **YAML-Driven Configuration**: Define tables in `config/tables.yaml`, business context in `config/system_prompt.yaml` +- πŸ—οΈ **CDK Infrastructure**: One-command deployment of Glue, Athena, S3, Lambda, API Gateway, CloudFront +- πŸ”„ **Dual Engine Support**: Works with Amazon Athena (default) and Amazon Redshift +- 🌐 **Web Frontend**: Clean UI with example queries and live schema panel + +## πŸ—οΈ Architecture + +![Text-to-SQL Data Analyst Architecture](docs/architecture.png) + +### Component Details + +#### Frontend (HTML + CSS + JavaScript) +- Query input with natural language support +- Example query buttons for quick exploration +- Schema visualization panel 
showing available tables and columns +- Results table with execution metrics + +#### Backend (AgentCore Runtime + Strands Agents) +- **discover_schema**: Discovers tables and columns from Glue Data Catalog using keyword-based relevance scoring +- **execute_query**: Validates SQL (SELECT-only), executes on Athena, returns typed results +- **PolicyValidator**: Code-level SQL validation β€” rejects DDL/DML, auto-applies LIMIT +- **System Prompt**: Loaded dynamically from `config/system_prompt.yaml` with business dictionary and few-shot examples + +#### AI Model (Amazon Bedrock) +- Primary: Claude Sonnet 4 β€” `us.anthropic.claude-sonnet-4-20250514-v1:0` + +#### Semantic Layer (AWS Glue + Athena) +- Glue Data Catalog as the schema metastore (tables, columns, types, comments) +- Amazon Athena as the serverless SQL engine over S3 Parquet data +- Tables defined in `config/tables.yaml` and created dynamically by CDK + +### AWS Services + +| Service | Purpose | +|---------|---------| +| Amazon Bedrock AgentCore | Agent runtime with conversational memory (STM + LTM) | +| Claude Sonnet 4 (Amazon Bedrock) | LLM for SQL generation and response formatting | +| AWS Glue Data Catalog | Schema registry / semantic layer | +| Amazon Athena | Serverless SQL engine over S3 | +| Amazon S3 | Data lake (Parquet) + frontend hosting | +| AWS Lambda | Backend orchestrator | +| Amazon API Gateway | REST API with CORS | +| Amazon CloudFront | CDN for frontend + API proxy | +| Amazon Bedrock Guardrails | Content filtering (hate, violence, prompt injection) | +| AWS CDK | Infrastructure as Code | + +## πŸš€ Quick Start + +### Prerequisites + +- Python 3.11+ +- Node.js 18+ (for CDK and AgentCore CLI) +- AWS CLI configured with credentials +- Docker (for Lambda asset bundling) +- AWS account with Amazon Bedrock access (Claude model enabled) +- AWS Permissions: `BedrockAgentCoreFullAccess`, `AmazonBedrockFullAccess` + +### 1. 
Setup + +```bash +cd 02-use-cases/text-to-sql-data-analyst + +python3 -m venv .venv + +# macOS / Linux +source .venv/bin/activate + +# Windows +.venv\Scripts\activate + +pip install -r requirements.txt + +cp .env.example .env +# Edit .env with your values +``` + +### 2. Define Your Tables + +Edit `config/tables.yaml` with your data structure: + +```yaml +database_name: "my_company_demo" +tables: + - name: customers + description: "Registered customers. Related to sales via customer_id." + columns: + - name: customer_id + type: bigint + comment: "PK - Unique customer identifier" + - name: name + type: string + comment: "Full name" + # ... more columns +``` + +### 3. Configure Business Context + +Edit `config/system_prompt.yaml`: +- `business_dictionary`: Define terms your users commonly use +- `examples`: Add 10-15 relevant SQL query examples (few-shot learning) +- `naming_conventions`: Document your tables and relationships + +### 4. Generate Sample Data (Optional) + +```bash +python scripts/init_demo_data.py +aws s3 cp data/demo/ s3://YOUR-BUCKET/data/ --recursive +``` + +### 5. Deploy Infrastructure + +```bash +cd cdk/ +pip install -r requirements.txt +cdk bootstrap aws://YOUR_ACCOUNT_ID/us-east-1 +cdk deploy --all +``` + +### 6. Deploy AgentCore Agent + +```bash +pip install bedrock-agentcore-starter-toolkit +agentcore configure -e agentcore_agent.py +agentcore launch +``` + +### 7. Test Locally + +```bash +# Start agent locally +python agentcore_agent.py + +# Test (in another terminal) +curl -X POST http://localhost:8080/invocations \ + -H "Content-Type: application/json" \ + -d '{"query": "How many customers do we have?"}' +``` + +## πŸ“‹ Usage + +### Sample Queries + +Once deployed, try these example queries in the web interface: + +| Natural Language Query | What It Does | +|----------------------|--------------| +| "How many customers do we have?" | Counts all records in the customers table | +| "What are the top 10 best-selling products?" 
| Ranks products by total sales volume | +| "Show me total revenue by month for 2024" | Aggregates sales by month | +| "Which customers spent more than $500?" | Filters customers by total purchase amount | +| "What is the average ticket per customer segment?" | Calculates average sale amount grouped by segment | +| "List products with low stock (less than 50 units)" | Filters products by inventory level | +| "Who are our premium customers?" | Finds customers in the premium segment | + +You can customize the example queries shown in the UI by editing `config/system_prompt.yaml`. + +### Asking Questions + +1. Open the frontend URL (CloudFront output from CDK deploy) +2. Type a natural language question (e.g., "What are the top 10 best-selling products?") +3. Or click one of the example query buttons +4. View the generated SQL and results + +### Adding Tables + +1. Add the table definition in `config/tables.yaml` +2. Upload Parquet data to S3 +3. Redeploy CDK: `cd cdk/ && cdk deploy` +4. Add relevant examples in `config/system_prompt.yaml` + +### Using Redshift Instead of Athena + +The `execute_query` tool supports Redshift out of the box. Set `engine_type="redshift"` and configure connection variables in `.env`. 
+ +## πŸ› οΈ Project Structure + +``` +text-to-sql-data-analyst/ +β”œβ”€β”€ agentcore_agent.py # AgentCore entry point (Strands SDK) +β”œβ”€β”€ config/ +β”‚ β”œβ”€β”€ tables.yaml # βš™οΈ CONFIGURE: Define your tables here +β”‚ └── system_prompt.yaml # βš™οΈ CONFIGURE: Prompt, examples, dictionary +β”œβ”€β”€ src/ +β”‚ β”œβ”€β”€ policy_validator.py # SQL validation (SELECT-only enforcement) +β”‚ └── tools/ +β”‚ β”œβ”€β”€ discover_schema.py # Schema discovery (Glue Data Catalog) +β”‚ └── execute_query.py # Query execution (Athena / Redshift) +β”œβ”€β”€ cdk/ +β”‚ β”œβ”€β”€ app.py # CDK app entry point +β”‚ └── stack.py # AWS infrastructure (reads tables.yaml) +β”œβ”€β”€ scripts/ +β”‚ └── init_demo_data.py # Sample data generator +β”œβ”€β”€ frontend/ # Web interface +β”‚ β”œβ”€β”€ index.html +β”‚ └── static/styles.css +β”œβ”€β”€ tests/ +β”‚ └── test_policy_validator.py # Unit tests +└── docs/ + └── DEEP-DIVE.md # Technical deep dive +``` + +## πŸ”’ Security + +### 4-Layer Validation + +``` +Layer 1: Amazon Bedrock Guardrails β†’ Blocks inappropriate content before LLM +Layer 2: System Prompt β†’ Instructs SELECT-only, LIMIT required +Layer 3: PolicyValidator β†’ Code-level SQL validation (rejects DDL/DML) +Layer 4: AWS Lake Formation β†’ IAM-level permissions (SELECT only on specific tables) +``` + +### Important + +> **⚠️** This sample application is meant for demo purposes and is not production ready. Please make sure to validate the code with your organization's security best practices. 
+ +## πŸ’° Cost Estimate (~1,000 queries/month) + +| Service | Monthly Cost | +|---------|-------------| +| Bedrock (Claude Sonnet 4) | ~$15-30 | +| Athena | ~$2-5 | +| Lambda | ~$1-3 | +| S3 + CloudFront | ~$1-3 | +| AgentCore Runtime | Included with Bedrock | +| **Total** | **~$20-40/month** | + +## 🧹 Cleanup + +```bash +# Destroy CDK stack +cd cdk/ +cdk destroy --all + +# Destroy AgentCore +agentcore destroy +``` + +## πŸ“„ License + +This project is licensed under the Apache-2.0 License. + +## πŸ“š Additional Resources + +For a detailed technical deep dive including request flow analysis, scaling strategies, and cost breakdowns, see [docs/DEEP-DIVE.md](docs/DEEP-DIVE.md). diff --git a/02-use-cases/text-to-sql-data-analyst/agentcore_agent.py b/02-use-cases/text-to-sql-data-analyst/agentcore_agent.py new file mode 100644 index 000000000..329b3425a --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/agentcore_agent.py @@ -0,0 +1,332 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Text-to-SQL Agent β€” Amazon Bedrock AgentCore + +Converts natural language questions to SQL and executes them on Athena. 
+ +CONFIGURATION: +- Edit config/tables.yaml to define your tables +- Edit config/system_prompt.yaml to customize the prompt and examples +- Set environment variables in .env +""" + +import os +import time +import uuid +import yaml +from datetime import datetime +from pathlib import Path + +from strands import tool +from bedrock_agentcore import BedrockAgentCoreApp + +# --- Configuration (from environment variables or .env) --- +GLUE_DATABASE = os.environ.get("GLUE_DATABASE_NAME", "my_company_demo") +AWS_REGION = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") +ATHENA_OUTPUT = os.environ.get( + "ATHENA_OUTPUT_LOCATION", "s3://my-company-text-to-sql-athena/results/" +) +MEMORY_ID = os.environ.get("AGENTCORE_MEMORY_ID", "") +PROJECT_NAME = os.environ.get("PROJECT_NAME", "my-company") + +app = BedrockAgentCoreApp() + +# Lazy-initialized clients +_glue = None +_athena = None + + +def _get_glue(): + global _glue + if _glue is None: + import boto3 + + _glue = boto3.client("glue", region_name=AWS_REGION) + return _glue + + +def _get_athena(): + global _athena + if _athena is None: + import boto3 + + _athena = boto3.client("athena", region_name=AWS_REGION) + return _athena + + +def _load_system_prompt(): + """Load and build the system prompt from config/system_prompt.yaml.""" + config_path = Path(__file__).parent / "config" / "system_prompt.yaml" + try: + with open(config_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + except Exception: + config = {} + + tables_info = config.get("naming_conventions", {}).get("tables", []) + tables_str = "\n".join(f" - {t}" for t in tables_info) + + relationships = config.get("naming_conventions", {}).get("relationships", []) + rels_str = "\n".join(f" - {r}" for r in relationships) + + guidelines = config.get("sql_guidelines", []) + guidelines_str = "\n".join(f" - {g}" for g in guidelines) + + business_dict = config.get("business_dictionary", {}) + biz_str = "\n".join(f" - {k}: {v}" for k, v in business_dict.items()) + + 
return f"""You are an expert SQL assistant for {PROJECT_NAME}. + +CONTEXT: +- Database: {GLUE_DATABASE} in AWS Glue Data Catalog +- Available tables: +{tables_str} +- Relationships: +{rels_str} +- Engine: Amazon Athena (Presto SQL dialect) + +BUSINESS DICTIONARY: +{biz_str} + +CAPABILITIES: +1. Convert natural language questions to SQL queries +2. Use discover_schema() to get table metadata from Glue +3. Generate optimized and safe SQL +4. Execute queries with execute_query() +5. Format results clearly + +WORKFLOW: +1. When you receive a question, first use discover_schema() with relevant keywords +2. Analyze the returned schema to understand available columns +3. Generate an appropriate SQL query +4. Execute with execute_query() +5. Present results clearly and concisely + +SQL RULES: +{guidelines_str} + +IMPORTANT SQL SYNTAX (Presto/Athena): +- Date columns are STRING type in 'YYYY-MM-DD' format +- To extract year: year(date_parse(date_col, '%Y-%m-%d')) +- To extract month: month(date_parse(date_col, '%Y-%m-%d')) + +RESPONSE FORMAT: +- Answer the question directly with the data obtained +- Do NOT include the SQL in your response (it is shown automatically in the frontend) +- Be concise, direct, and friendly""" + + +@tool +def discover_schema(keywords=None): + """ + Discover the schema of available tables in the database. 
+ + Args: + keywords: Optional list of keywords to filter tables + + Returns: + Dictionary with table and column information + """ + try: + response = _get_glue().get_tables(DatabaseName=GLUE_DATABASE) + tables_info = [] + for table in response.get("TableList", []): + name = table["Name"] + if keywords: + kw_lower = [k.lower() for k in keywords] + if not any(kw in name.lower() for kw in kw_lower): + continue + columns = [ + { + "name": c["Name"], + "type": c["Type"], + "comment": c.get("Comment", ""), + } + for c in table.get("StorageDescriptor", {}).get("Columns", []) + ] + tables_info.append( + { + "name": name, + "columns": columns, + "location": table.get("StorageDescriptor", {}).get("Location", ""), + "row_count": table.get("Parameters", {}).get( + "numRows", "unknown" + ), + } + ) + return { + "database": GLUE_DATABASE, + "tables": tables_info, + "total_tables": len(tables_info), + } + except Exception as e: + return {"database": GLUE_DATABASE, "tables": [], "error": str(e)} + + +@tool +def execute_query(sql: str): + """ + Execute a SQL SELECT query on Amazon Athena. + Only SELECT queries are allowed. 
+ + Args: + sql: SQL query to execute (must be SELECT) + + Returns: + Dictionary with query results + """ + try: + sql_upper = sql.strip().upper() + if not sql_upper.startswith("SELECT") and not sql_upper.startswith("WITH"): + return {"success": False, "error": "Only SELECT queries are allowed"} + + forbidden = [ + "DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE", + ] + for word in forbidden: + if word in sql_upper: + return {"success": False, "error": f"Operation not allowed: {word}"} + + athena = _get_athena() + response = athena.start_query_execution( + QueryString=sql, + QueryExecutionContext={"Database": GLUE_DATABASE}, + ResultConfiguration={"OutputLocation": ATHENA_OUTPUT}, + ) + qid = response["QueryExecutionId"] + + for _ in range(60): + status_resp = athena.get_query_execution(QueryExecutionId=qid) + state = status_resp["QueryExecution"]["Status"]["State"] + if state in ("SUCCEEDED", "FAILED", "CANCELLED"): + break + time.sleep(0.5) + + if state != "SUCCEEDED": + err = status_resp["QueryExecution"]["Status"].get( + "StateChangeReason", "Query failed" + ) + return {"success": False, "error": err} + + results = athena.get_query_results(QueryExecutionId=qid, MaxResults=1000) + cols = [ + {"name": c["Name"], "type": c["Type"]} + for c in results["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] + ] + rows = [] + for row in results["ResultSet"]["Rows"][1:]: + rows.append( + { + cols[i]["name"]: row["Data"][i].get("VarCharValue") + for i in range(len(cols)) + } + ) + + stats = status_resp["QueryExecution"]["Statistics"] + return { + "success": True, + "sql": sql, + "columns": cols, + "rows": rows, + "row_count": len(rows), + "execution_time_ms": stats.get("TotalExecutionTimeInMillis", 0), + "data_scanned_bytes": stats.get("DataScannedInBytes", 0), + } + except Exception as e: + return {"success": False, "error": str(e)} + + +# Load system prompt at module level +SYSTEM_PROMPT = _load_system_prompt() + + +@app.entrypoint +def invoke(payload, 
context=None): + """AgentCore Runtime entrypoint.""" + try: + query = ( + payload.get("query", payload.get("prompt", "")) + if isinstance(payload, dict) + else str(payload) + ) + session_id = ( + payload.get("session_id", str(uuid.uuid4())) + if isinstance(payload, dict) + else str(uuid.uuid4()) + ) + user_id = ( + payload.get("user_id", "demo_user") + if isinstance(payload, dict) + else "demo_user" + ) + + if not query: + return { + "success": False, + "error": "No query provided", + "session_id": session_id, + } + + from strands import Agent + + session_manager = None + + # Configure Memory if available + if MEMORY_ID: + try: + from bedrock_agentcore.memory.integrations.strands.config import ( + AgentCoreMemoryConfig, + ) + from bedrock_agentcore.memory.integrations.strands.session_manager import ( + AgentCoreMemorySessionManager, + ) + + memory_config = AgentCoreMemoryConfig( + memory_id=MEMORY_ID, + session_id=session_id, + actor_id=user_id, + ) + session_manager = AgentCoreMemorySessionManager( + agentcore_memory_config=memory_config, + region_name=AWS_REGION, + ) + except Exception: + pass # Memory is optional + + agent = Agent( + name=f"{PROJECT_NAME}TextToSQLAgent", + model="us.anthropic.claude-sonnet-4-20250514-v1:0", + system_prompt=SYSTEM_PROMPT, + tools=[discover_schema, execute_query], + session_manager=session_manager, + ) + + start = time.time() + response = agent(query) + latency = int((time.time() - start) * 1000) + + response_text = str(response) + if hasattr(response, "message") and isinstance(response.message, dict): + content = response.message.get("content", []) + if content and isinstance(content, list) and len(content) > 0: + response_text = content[0].get("text", str(response)) + + return { + "success": True, + "response": response_text, + "session_id": session_id, + "timestamp": datetime.now().isoformat(), + "latency_ms": latency, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "session_id": session_id if 
"session_id" in dir() else str(uuid.uuid4()), + } + + +if __name__ == "__main__": + app.run() diff --git a/02-use-cases/text-to-sql-data-analyst/cdk/app.py b/02-use-cases/text-to-sql-data-analyst/cdk/app.py new file mode 100644 index 000000000..eba9a8be9 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/cdk/app.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +CDK App β€” Text-to-SQL with Amazon Bedrock AgentCore + +Deploys the complete infrastructure: +- Data Lake (S3) + Glue Data Catalog (from config/tables.yaml) +- Backend (Lambda + API Gateway) +- Frontend (S3 + CloudFront) +- Amazon Bedrock Guardrails +""" + +import os +import aws_cdk as cdk +from stack import TextToSQLStack + +app = cdk.App() + +# Configuration β€” edit these values or set environment variables +PROJECT_NAME = os.environ.get("PROJECT_NAME", "my-company") +AWS_ACCOUNT = os.environ.get("AWS_ACCOUNT_ID", "123456789012") +AWS_REGION = os.environ.get("AWS_REGION", "us-east-1") + +env = cdk.Environment(account=AWS_ACCOUNT, region=AWS_REGION) + +TextToSQLStack( + app, + f"{PROJECT_NAME}-TextToSQL", + project_name=PROJECT_NAME, + env=env, + description=f"Text-to-SQL GenAI Stack for {PROJECT_NAME}", +) + +cdk.Tags.of(app).add("Project", f"{PROJECT_NAME}-TextToSQL") +cdk.Tags.of(app).add("ManagedBy", "CDK") + +app.synth() diff --git a/02-use-cases/text-to-sql-data-analyst/cdk/requirements.txt b/02-use-cases/text-to-sql-data-analyst/cdk/requirements.txt new file mode 100644 index 000000000..d64ac4534 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/cdk/requirements.txt @@ -0,0 +1,6 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +aws-cdk-lib>=2.133.0 +constructs>=10.0.0,<11.0.0 +pyyaml>=6.0 diff --git a/02-use-cases/text-to-sql-data-analyst/cdk/stack.py b/02-use-cases/text-to-sql-data-analyst/cdk/stack.py new file mode 100644 index 000000000..34bfdce8e --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/cdk/stack.py @@ -0,0 +1,339 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +CDK Stack β€” Text-to-SQL with Amazon Bedrock AgentCore + +Reads table definitions from config/tables.yaml and deploys: +- S3 Data Lake + Athena results bucket +- Glue Database + Tables (dynamic from YAML) +- Amazon Bedrock Guardrails +- Lambda + API Gateway (backend) +- S3 + CloudFront (frontend) +""" + +import yaml +from pathlib import Path +from constructs import Construct +from aws_cdk import ( + Stack, Duration, RemovalPolicy, CfnOutput, + aws_s3 as s3, + aws_s3_deployment as s3_deploy, + aws_lambda as lambda_, + aws_apigateway as apigw, + aws_iam as iam, + aws_glue as glue, + aws_cloudfront as cloudfront, + aws_cloudfront_origins as origins, + aws_bedrock as bedrock, +) + + +class TextToSQLStack(Stack): + + def __init__( + self, scope: Construct, construct_id: str, project_name: str, **kwargs + ) -> None: + super().__init__(scope, construct_id, **kwargs) + self.project_name = project_name.lower().replace(" ", "-") + + # Load table definitions from YAML + self.tables_config = self._load_tables_config() + self.database_name = self.tables_config.get( + "database_name", f"{self.project_name.replace('-', '_')}_demo" + ) + + self._create_data_lake() + self._create_glue_catalog() + self._create_guardrails() + self._create_backend() + self._create_frontend() + self._create_outputs() + + def _load_tables_config(self): + """Load config/tables.yaml with table definitions.""" + config_path = Path(__file__).parent.parent / "config" / "tables.yaml" + try: + with open(config_path, "r", encoding="utf-8") as 
f: + return yaml.safe_load(f) + except Exception as e: + print(f"Warning: Could not load tables.yaml: {e}") + return {"database_name": "demo", "tables": []} + + def _create_data_lake(self): + self.data_bucket = s3.Bucket( + self, "DataLakeBucket", + bucket_name=f"{self.project_name}-text-to-sql-data", + removal_policy=RemovalPolicy.DESTROY, + auto_delete_objects=True, + encryption=s3.BucketEncryption.S3_MANAGED, + block_public_access=s3.BlockPublicAccess.BLOCK_ALL, + ) + self.athena_results_bucket = s3.Bucket( + self, "AthenaResultsBucket", + bucket_name=f"{self.project_name}-text-to-sql-athena", + removal_policy=RemovalPolicy.DESTROY, + auto_delete_objects=True, + lifecycle_rules=[ + s3.LifecycleRule(expiration=Duration.days(7), enabled=True) + ], + ) + + def _create_glue_catalog(self): + """Create Glue Database and Tables dynamically from tables.yaml.""" + db_description = self.tables_config.get( + "database_description", + f"Text-to-SQL database for {self.project_name}", + ) + s3_prefix = self.tables_config.get("s3_data_prefix", "data") + + self.glue_database = glue.CfnDatabase( + self, "GlueDatabase", + catalog_id=self.account, + database_input=glue.CfnDatabase.DatabaseInputProperty( + name=self.database_name, + description=db_description, + ), + ) + + # Create tables dynamically from YAML + for i, table_def in enumerate(self.tables_config.get("tables", [])): + table_name = table_def["name"] + columns = [ + glue.CfnTable.ColumnProperty( + name=col["name"], + type=col["type"], + comment=col.get("comment", ""), + ) + for col in table_def.get("columns", []) + ] + + table = glue.CfnTable( + self, f"Table{i}_{table_name}", + catalog_id=self.account, + database_name=self.database_name, + table_input=glue.CfnTable.TableInputProperty( + name=table_name, + description=table_def.get("description", ""), + table_type="EXTERNAL_TABLE", + parameters={"classification": "parquet"}, + storage_descriptor=glue.CfnTable.StorageDescriptorProperty( + 
location=f"s3://{self.data_bucket.bucket_name}/{s3_prefix}/{table_name}/", + input_format="org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + output_format="org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + serde_info=glue.CfnTable.SerdeInfoProperty( + serialization_library="org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + columns=columns, + ), + ), + ) + table.add_dependency(self.glue_database) + + def _create_guardrails(self): + """Create Amazon Bedrock Guardrails for content filtering.""" + self.guardrail = bedrock.CfnGuardrail( + self, "ContentGuardrail", + name=f"{self.project_name}-content-guardrail", + description="Guardrail to filter inappropriate content in queries", + blocked_input_messaging="Sorry, your query contains content that cannot be processed. Please rephrase your question focusing on data queries.", + blocked_outputs_messaging="Sorry, I cannot generate that response. Please try a different query.", + content_policy_config=bedrock.CfnGuardrail.ContentPolicyConfigProperty( + filters_config=[ + bedrock.CfnGuardrail.ContentFilterConfigProperty( + type=t, input_strength="HIGH", output_strength="HIGH" + ) + for t in ["HATE", "SEXUAL", "VIOLENCE", "MISCONDUCT", "INSULTS"] + ] + ), + topic_policy_config=bedrock.CfnGuardrail.TopicPolicyConfigProperty( + topics_config=[ + bedrock.CfnGuardrail.TopicConfigProperty( + name="Politics", type="DENY", + definition="Discussions about political parties, elections, or political ideologies.", + examples=[ + "What do you think about the president?", + "Which is the best political party?", + ], + ), + bedrock.CfnGuardrail.TopicConfigProperty( + name="Religion", type="DENY", + definition="Discussions about specific religions or theological debates.", + examples=[ + "Which is the best religion?", + "Does God exist?", + ], + ), + bedrock.CfnGuardrail.TopicConfigProperty( + name="Violence", type="DENY", + definition="Discussions about violence, weapons, or illegal 
activities.", + examples=[ + "How to make a bomb?", + "How to get weapons?", + ], + ), + ], + ), + word_policy_config=bedrock.CfnGuardrail.WordPolicyConfigProperty( + managed_word_lists_config=[ + bedrock.CfnGuardrail.ManagedWordsConfigProperty(type="PROFANITY") + ] + ), + ) + self.guardrail_version = bedrock.CfnGuardrailVersion( + self, "GuardrailVersion", + guardrail_identifier=self.guardrail.attr_guardrail_id, + description="Version 1", + ) + + def _create_backend(self): + """Create Lambda + API Gateway.""" + lambda_role = iam.Role( + self, "LambdaRole", + assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"), + managed_policies=[ + iam.ManagedPolicy.from_aws_managed_policy_name( + "service-role/AWSLambdaBasicExecutionRole" + ) + ], + ) + # Permissions for Glue, Athena, S3, Amazon Bedrock + for actions, resources in [ + ( + ["glue:GetDatabase", "glue:GetDatabases", "glue:GetTable", "glue:GetTables"], + ["*"], + ), + ( + [ + "athena:StartQueryExecution", "athena:GetQueryExecution", + "athena:GetQueryResults", "athena:StopQueryExecution", + ], + ["*"], + ), + ( + ["s3:GetObject", "s3:PutObject", "s3:ListBucket", "s3:GetBucketLocation"], + [ + self.data_bucket.bucket_arn, + f"{self.data_bucket.bucket_arn}/*", + self.athena_results_bucket.bucket_arn, + f"{self.athena_results_bucket.bucket_arn}/*", + ], + ), + ( + ["bedrock:InvokeModel"], + ["arn:aws:bedrock:*::foundation-model/anthropic.claude-*"], + ), + ( + ["bedrock:ApplyGuardrail", "bedrock:GetGuardrail"], + [self.guardrail.attr_guardrail_arn], + ), + ]: + lambda_role.add_to_policy( + iam.PolicyStatement(actions=actions, resources=resources) + ) + + self.backend_lambda = lambda_.Function( + self, "BackendLambda", + function_name=f"{self.project_name}-text-to-sql-api", + runtime=lambda_.Runtime.PYTHON_3_11, + handler="lambda_handler.handler", + code=lambda_.Code.from_asset("../lambda_package"), + role=lambda_role, + timeout=Duration.seconds(60), + memory_size=512, + environment={ + "GLUE_DATABASE_NAME": 
self.database_name, + "ATHENA_OUTPUT_LOCATION": f"s3://{self.athena_results_bucket.bucket_name}/results/", + "PROJECT_NAME": self.project_name, + "GUARDRAIL_ID": self.guardrail.attr_guardrail_id, + "GUARDRAIL_VERSION": "DRAFT", + }, + ) + + self.api = apigw.RestApi( + self, "TextToSQLApi", + rest_api_name=f"{self.project_name}-text-to-sql-api", + description="Text-to-SQL GenAI API", + default_cors_preflight_options=apigw.CorsOptions( + allow_origins=apigw.Cors.ALL_ORIGINS, + allow_methods=apigw.Cors.ALL_METHODS, + allow_headers=["Content-Type", "Authorization"], + ), + ) + api_resource = self.api.root.add_resource("api") + for path in ["query", "examples", "health"]: + method = "POST" if path == "query" else "GET" + api_resource.add_resource(path).add_method( + method, apigw.LambdaIntegration(self.backend_lambda) + ) + + def _create_frontend(self): + """Create S3 + CloudFront for the frontend.""" + self.frontend_bucket = s3.Bucket( + self, "FrontendBucket", + bucket_name=f"{self.project_name}-text-to-sql-frontend", + removal_policy=RemovalPolicy.DESTROY, + auto_delete_objects=True, + block_public_access=s3.BlockPublicAccess.BLOCK_ALL, + ) + oai = cloudfront.OriginAccessIdentity( + self, "OAI", comment=f"OAI for {self.project_name}" + ) + self.frontend_bucket.grant_read(oai) + + self.distribution = cloudfront.Distribution( + self, "Distribution", + default_behavior=cloudfront.BehaviorOptions( + origin=origins.S3Origin( + self.frontend_bucket, origin_access_identity=oai + ), + viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.REDIRECT_TO_HTTPS, + cache_policy=cloudfront.CachePolicy.CACHING_DISABLED, + ), + additional_behaviors={ + "/api/*": cloudfront.BehaviorOptions( + origin=origins.RestApiOrigin(self.api), + viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.HTTPS_ONLY, + cache_policy=cloudfront.CachePolicy.CACHING_DISABLED, + allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL, + ) + }, + default_root_object="index.html", + error_responses=[ + 
cloudfront.ErrorResponse( + http_status=404, + response_http_status=200, + response_page_path="/index.html", + ) + ], + ) + s3_deploy.BucketDeployment( + self, "DeployFrontend", + sources=[s3_deploy.Source.asset("../frontend")], + destination_bucket=self.frontend_bucket, + distribution=self.distribution, + distribution_paths=["/*"], + ) + + def _create_outputs(self): + CfnOutput( + self, "FrontendURL", + value=f"https://{self.distribution.distribution_domain_name}", + description="Frontend URL", + ) + CfnOutput(self, "ApiURL", value=self.api.url, description="API Gateway URL") + CfnOutput( + self, "DataBucket", + value=self.data_bucket.bucket_name, + description="S3 Data Lake bucket", + ) + CfnOutput( + self, "GlueDatabase", + value=self.database_name, + description="Glue Database name", + ) + CfnOutput( + self, "GuardrailId", + value=self.guardrail.attr_guardrail_id, + description="Amazon Bedrock Guardrail ID", + ) diff --git a/02-use-cases/text-to-sql-data-analyst/config/system_prompt.yaml b/02-use-cases/text-to-sql-data-analyst/config/system_prompt.yaml new file mode 100644 index 000000000..849bc6ad9 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/config/system_prompt.yaml @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ============================================================================= +# System Prompt Configuration β€” Text-to-SQL Agent +# ============================================================================= +# INSTRUCTIONS: +# 1. Update business_dictionary with your domain terms +# 2. Update naming_conventions with your tables and columns +# 3. Add relevant SQL query examples (few-shot learning) +# 4. Adjust sql_guidelines if using Redshift instead of Athena + +sql_dialect: "athena" # "athena" or "redshift" + +# Business dictionary: define terms your users commonly use. +# The agent uses these to better understand queries. 
+business_dictionary: + active_customer: "Customer who has made at least one purchase in the last 90 days" + average_ticket: "Average sale amount per transaction" + top_product: "Product with the highest sales volume in units" + profitable_category: "Category with the highest profit margin" + premium_customer: "Customer in the 'premium' segment with purchases over $100" + recent_sale: "Sale made in the last 30 days" + low_stock: "Product with fewer than 50 units in inventory" + +# Naming conventions: helps the agent map queries to tables +naming_conventions: + tables: + - "customers: Registered customer information" + - "products: Product catalog" + - "sales: Sales transaction records" + columns: + - "IDs end with _id (customer_id, product_id, sale_id)" + - "Dates: registration_date, sale_date (YYYY-MM-DD format)" + - "Amounts: total_amount, price (in USD)" + - "Quantities: quantity, stock" + relationships: + - "sales.customer_id -> customers.customer_id" + - "sales.product_id -> products.product_id" + +# Query examples: the agent uses these as reference (few-shot learning). +# Add at least 10-15 examples relevant to your business. +examples: + - natural_query: "How many customers do we have?" + sql: "SELECT COUNT(*) as total_customers FROM customers" + explanation: "Counts all records in the customers table" + + - natural_query: "How many products are in the catalog?" + sql: "SELECT COUNT(*) as total_products FROM products" + explanation: "Counts all available products" + + - natural_query: "How many sales have been made?" + sql: "SELECT COUNT(*) as total_sales FROM sales" + explanation: "Counts all sales transactions" + + - natural_query: "What are the top 10 best-selling products?" 
+ sql: | + SELECT p.name, SUM(s.quantity) as total_sold + FROM sales s + JOIN products p ON s.product_id = p.product_id + GROUP BY p.name + ORDER BY total_sold DESC + LIMIT 10 + explanation: "Groups sales by product and orders by total quantity" + + - natural_query: "What is the average ticket by customer segment?" + sql: | + SELECT c.segment, ROUND(AVG(s.total_amount), 2) as avg_ticket + FROM sales s + JOIN customers c ON s.customer_id = c.customer_id + GROUP BY c.segment + ORDER BY avg_ticket DESC + explanation: "Calculates average total_amount grouped by segment" + + - natural_query: "Which product category generates the most revenue?" + sql: | + SELECT p.category, ROUND(SUM(s.total_amount), 2) as total_revenue + FROM sales s + JOIN products p ON s.product_id = p.product_id + GROUP BY p.category + ORDER BY total_revenue DESC + explanation: "Sums sale amounts by category" + + - natural_query: "Which products have low stock?" + sql: | + SELECT name, stock, category + FROM products + WHERE stock < 50 + ORDER BY stock ASC + explanation: "Filters products with fewer than 50 units" + + - natural_query: "What is the total sales amount?" + sql: "SELECT ROUND(SUM(total_amount), 2) as total_amount FROM sales" + explanation: "Sums all sale amounts" + + - natural_query: "What is the most used payment method?" + sql: | + SELECT payment_method, COUNT(*) as num_transactions + FROM sales + GROUP BY payment_method + ORDER BY num_transactions DESC + explanation: "Counts transactions by payment method" + + - natural_query: "How many customers do we have per country?" + sql: | + SELECT country, COUNT(*) as total_customers + FROM customers + GROUP BY country + ORDER BY total_customers DESC + explanation: "Counts customers grouped by country" + + - natural_query: "Which suppliers have the most products?" 
+ sql: | + SELECT supplier, COUNT(*) as num_products + FROM products + GROUP BY supplier + ORDER BY num_products DESC + explanation: "Counts products by supplier" + +# Agent workflow +workflow: + steps: + - step: 1 + action: "Analyze the user's question" + - step: 2 + action: "Discover relevant schema (discover_schema)" + - step: 3 + action: "Generate SQL with context" + - step: 4 + action: "Validate SQL (PolicyValidator)" + - step: 5 + action: "Execute query (execute_query)" + - step: 6 + action: "Format and present results" + +# SQL rules the agent must follow +sql_guidelines: + - "Always use SELECT, never modification commands" + - "Add LIMIT when appropriate" + - "Use explicit JOINs (not implicit joins)" + - "Use table aliases (e.g., FROM sales s)" + - "Order results when relevant" + - "Use ROUND() for decimals" + - "Dates are STRING in YYYY-MM-DD format (use date_parse in Athena)" + +# Response format +response_format: + include_sql: true + include_explanation: true + max_rows_display: 50 + date_format: "YYYY-MM-DD" + number_format: "1,000.00" + language: "english" diff --git a/02-use-cases/text-to-sql-data-analyst/config/tables.yaml b/02-use-cases/text-to-sql-data-analyst/config/tables.yaml new file mode 100644 index 000000000..1987cb523 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/config/tables.yaml @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ============================================================================= +# Table Definitions for Text-to-SQL +# ============================================================================= +# This file defines the tables that will be created in AWS Glue Data Catalog +# and that the agent can query via Athena. +# +# INSTRUCTIONS: +# 1. Replace the example tables with your own +# 2. Define columns with clear types and descriptions +# 3. Document relationships (FK) in column comments +# 4. 
Redeploy CDK after changes: cd cdk/ && cdk deploy + +# Glue database name (no spaces, use underscores) +database_name: "my_company_demo" + +# Database description +database_description: "Sample database for natural language queries" + +# Base S3 location for data (Parquet files) +s3_data_prefix: "data" + +tables: + # --- Table 1: Customers --- + - name: customers + description: > + Registered customers on the platform. + Related to sales via customer_id. + columns: + - name: customer_id + type: bigint + comment: "PK - Unique customer identifier" + - name: name + type: string + comment: "Full name" + - name: email + type: string + comment: "Email address" + - name: registration_date + type: string + comment: "Registration date in YYYY-MM-DD format" + - name: country + type: string + comment: "Country of residence" + - name: segment + type: string + comment: "Customer segment: premium, regular, new" + + # --- Table 2: Products --- + - name: products + description: > + Product catalog available for sale. + Related to sales via product_id. + columns: + - name: product_id + type: int + comment: "PK - Unique product identifier" + - name: name + type: string + comment: "Product name" + - name: category + type: string + comment: "Product category" + - name: price + type: double + comment: "Sale price in USD" + - name: stock + type: int + comment: "Available inventory units. Low < 50" + - name: supplier + type: string + comment: "Supplier name" + + # --- Table 3: Sales --- + - name: sales + description: > + Sales transaction records. 
+ FK: customer_id -> customers.customer_id, product_id -> products.product_id + columns: + - name: sale_id + type: bigint + comment: "PK - Unique sale identifier" + - name: customer_id + type: bigint + comment: "FK -> customers.customer_id" + - name: product_id + type: bigint + comment: "FK -> products.product_id" + - name: sale_date + type: string + comment: "Sale date in YYYY-MM-DD format" + - name: quantity + type: int + comment: "Units purchased" + - name: total_amount + type: double + comment: "Total amount in USD (with discount applied)" + - name: discount + type: double + comment: "Discount percentage applied (0.0 to 1.0)" + - name: payment_method + type: string + comment: "Payment method: credit_card, transfer, cash" diff --git a/02-use-cases/text-to-sql-data-analyst/docs/DEEP-DIVE.md b/02-use-cases/text-to-sql-data-analyst/docs/DEEP-DIVE.md new file mode 100644 index 000000000..a0e9167e4 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/docs/DEEP-DIVE.md @@ -0,0 +1,276 @@ +# Text-to-SQL with Amazon Bedrock AgentCore β€” Technical Deep Dive + +## 1. 
Architecture Overview + +**Main flow:** User β†’ CloudFront β†’ API Gateway β†’ Lambda β†’ Amazon Bedrock Guardrails β†’ Claude Sonnet 4 (Strands SDK) β†’ Glue/Athena/S3 + +**Dual memory:** STM (active session) + Semantic Memory (SQL patterns, TTL 90 days) + +**Observability:** CloudWatch (logs + custom metrics) + +### AWS Services + +| Service | Purpose | Estimated Cost | +|---------|---------|---------------| +| CloudFront | CDN for static frontend | ~$0.01/mo (free tier) | +| API Gateway | REST API with CORS | ~$3.50/million requests | +| Lambda | Orchestrator (Python 3.11, 512MB, ARM64) | ~$0.20/million invocations | +| Amazon Bedrock AgentCore | Agent runtime with memory | Included with Amazon Bedrock | +| Claude Sonnet 4 | LLM for SQL generation + responses | ~$3/1M input + $15/1M output tokens | +| Glue Data Catalog | Table and column metastore | Free (first 1M objects) | +| Athena | Serverless SQL engine over S3 | $5/TB scanned | +| S3 | Data Lake (Parquet columnar) | ~$0.023/GB/mo | +| CloudWatch | Logs, metrics, observability | ~$0.50/GB ingested | + +### Estimated Monthly Cost (~1,000 queries/month) + +| Component | Calculation | Cost | +|-----------|------------|------| +| Claude Sonnet 4 | ~900 tokens/query Γ— 1,000 = 900K tokens | ~$5.40 | +| Athena | ~2KB/query Γ— 1,000 = 2MB scanned | ~$0.01 | +| Lambda | 1,000 invocations Γ— 15s Γ— 512MB | ~$0.10 | +| API Gateway | 1,000 requests | ~$0.004 | +| Keep-alive | 14,400 invocations/mo (every 3 min) | ~$13.00 | +| **Total** | | **~$18.50/mo** | + +--- + +## 2. 
Detailed Request Flow + +### 2.1 First query in session (full flow ~12-15s) + +``` +Time Step Who decides Duration +────── ──── ─────────── ──────── + 0ms User submits question Frontend - + 50ms CloudFront β†’ API Gateway Infrastructure ~50ms +100ms API Gateway β†’ Lambda AWS ~50ms +150ms Lambda β†’ AgentCore Runtime invoke_agent_runtime ~100ms +250ms AgentCore loads STM Memory AgentCore ~200ms + (session context) +450ms Claude receives question Claude Sonnet 4 - + + system prompt + memory +450ms Claude decides: "I need Claude Sonnet 4 ~2,000ms + the DB schema" +2.5s Tool call: discover_schema() Claude β†’ Glue ~600ms +3.1s Claude receives schema Claude Sonnet 4 ~3,000ms + (3 tables, 23 columns) + Generates optimized SQL +6.1s Tool call: execute_query() Claude β†’ Athena ~700ms +6.8s Claude receives results Claude Sonnet 4 ~2,000ms + Formats response with + data + context +8.8s AgentCore saves to STM AgentCore Memory ~200ms +9.0s Response β†’ Lambda AgentCore - +9.1s Lambda extracts SQL + metrics Lambda ~200ms +9.3s Response β†’ API GW β†’ User Infrastructure ~50ms +``` + +### 2.2 Repeated query in same session (STM ~3-5s) + +``` +Time Step Who decides Duration +────── ──── ─────────── ──────── + 0ms User repeats question Frontend - +150ms Lambda β†’ AgentCore Runtime invoke_agent_runtime ~150ms +300ms AgentCore loads STM Memory AgentCore ~200ms + (includes previous Q&A) +500ms Claude sees in memory: Claude Sonnet 4 ~2,500ms + "I already answered this" + Responds from context + WITHOUT calling tools +3.0s Direct response AgentCore β†’ Lambda ~100ms +``` + +### 2.3 New query in same session (cached schema ~6-8s) + +``` +Time Step Who decides Duration +────── ──── ─────────── ──────── + 0ms User asks something new Frontend - +150ms Lambda β†’ AgentCore Runtime invoke_agent_runtime ~150ms +300ms AgentCore loads STM Memory AgentCore ~200ms + (includes schema from before) +500ms Claude sees schema in memory Claude Sonnet 4 ~2,500ms + Does NOT call discover_schema() + 
Generates SQL directly +3.0s Tool call: execute_query() Claude β†’ Athena ~700ms +3.7s Claude formats response Claude Sonnet 4 ~2,000ms +5.7s Response β†’ User Infrastructure ~100ms +``` + +--- + +## 3. Semantic Layer: From Data Lake to Data Lakehouse + +### 3.1 The Semantic Layer in AWS + +The semantic layer transforms a Data Lake (raw storage in S3) into a functional Data Lakehouse, combining S3 flexibility with Data Warehouse structure. + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ AWS SEMANTIC LAYER β€” 3 PILLARS β”‚ +β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ AWS Glue Data β”‚ β”‚ AWS Lake β”‚ β”‚ Amazon Redshift β”‚ β”‚ +β”‚ β”‚ Catalog β”‚ β”‚ Formation β”‚ β”‚ Spectrum β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ Centralized β”‚ β”‚ Governance and β”‚ β”‚ High-performance β”‚ β”‚ +β”‚ β”‚ metastore. β”‚ β”‚ security. β”‚ β”‚ semantic layer β”‚ β”‚ +β”‚ β”‚ Defines schema β”‚ β”‚ Semantic β”‚ β”‚ over S3. β”‚ β”‚ +β”‚ β”‚ for raw data β”‚ β”‚ permissions: β”‚ β”‚ Query without β”‚ β”‚ +β”‚ β”‚ in S3 (crawlers) β”‚ β”‚ who sees what β”‚ β”‚ importing data. β”‚ β”‚ +β”‚ β”‚ to make it β”‚ β”‚ data from the β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ queryable. β”‚ β”‚ catalog. 
β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ βœ… INCLUDED β”‚ β”‚ βœ… INCLUDED β”‚ β”‚ ❌ NOT NEEDED β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ +β”‚ We use Athena as the serverless SQL engine instead of Redshift β”‚ +β”‚ Spectrum, since it requires no cluster and costs $0 at low volumes. β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### 3.2 How Claude Interprets Semantics + +Claude Sonnet 4 adds a semantic inference layer that doesn't exist in traditional BI tools: + +- **System prompt** (from `config/system_prompt.yaml`): business dictionary, naming conventions, few-shot examples +- **Column name inference**: `total_amount` β†’ money, use SUM/AVG; `category` β†’ use GROUP BY; `sale_date` β†’ temporal, use date_parse +- This differentiates the solution from traditional BI: users ask in natural language and Claude translates intent to correct SQL + +--- + +## 4. 
Security Layers + +``` +β”Œβ”€ LAYER 1: Amazon Bedrock Guardrails ─────────────────┐ +β”‚ βœ“ Blocks politics, religion, violence, sexual, hate β”‚ +β”‚ βœ“ Blocks prompt injection β”‚ +β”‚ βœ— Evaluates BEFORE Claude processes β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + ↓ +β”Œβ”€ LAYER 2: System Prompt ─────────────────────────────┐ +β”‚ βœ“ SELECT only β”‚ +β”‚ βœ“ Always include LIMIT β”‚ +β”‚ βœ“ Use exact schema names β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + ↓ +β”Œβ”€ LAYER 3: PolicyValidator + execute_query() ─────────┐ +β”‚ βœ“ Rejects DROP, DELETE, INSERT, UPDATE, ALTER, CREATEβ”‚ +β”‚ βœ“ Only executes if starts with SELECT β”‚ +β”‚ βœ“ Athena validates SQL syntax β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + ↓ +β”Œβ”€ LAYER 4: Lake Formation (permissions) ──────────────┐ +β”‚ βœ“ Agent role only has SELECT on specific tables β”‚ +β”‚ βœ“ Cannot access other databases β”‚ +β”‚ βœ“ Cannot create/modify tables β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +--- + +## 5. 
Scaling for Production + +### 5.1 Production scenario (~50 tables, ~500 users) + +``` +CURRENT (POC) PRODUCTION +──────────── ────────── +3 tables 50+ tables +~6K records Millions of records +1 concurrent user 50-100 concurrent +No authentication Cognito + API Keys +No response cache ElastiCache/DynamoDB cache +Athena on-demand Athena provisioned capacity +Claude Sonnet 4 Sonnet 4 (complex) + Haiku (simple) +``` + +### 5.2 Key production improvements + +| Improvement | Current Latency | Target | How | +|-------------|----------------|--------|-----| +| DynamoDB cache | 12-15s (first) | <500ms (cache hit) | Hash query β†’ cached response with TTL | +| Provisioned Throughput | Variable | Consistent | Reserve Claude capacity in Amazon Bedrock | +| Athena provisioned | ~700ms | ~200ms | Reserved capacity for frequent queries | +| Pre-loaded schema | 600ms (Glue call) | 0ms | Inject schema in system prompt | +| Model routing | 12-15s always | 3-5s (simple) | Haiku for COUNT/simple, Sonnet for JOINs | + +### 5.3 Suggested roadmap + +``` +PHASE 1 (2-3 weeks): Minimum production +──────────────────────────────────────── +βœ“ Connect to real data (existing or new Glue Catalog) +βœ“ Enrich schema with semantic comments +βœ“ Add 15-20 few-shot examples for business queries +βœ“ Cognito for authentication +βœ“ DynamoDB cache for frequent queries +βœ“ Estimate: ~$50-100/mo + +PHASE 2 (2-3 weeks): Optimization +──────────────────────────────────── +βœ“ Model routing (Haiku for simple, Sonnet for complex) +βœ“ Pre-loaded schema in prompt (eliminate discover_schema) +βœ“ Athena provisioned capacity +βœ“ CloudWatch observability dashboard +βœ“ Estimate: ~$150-300/mo (depending on volume) + +PHASE 3 (3-4 weeks): Enterprise +──────────────────────────────────── +βœ“ Multi-tenancy with Lake Formation +βœ“ Row-level security by department +βœ“ Feedback loop (user marks incorrect responses) +βœ“ Alerts and SLAs +βœ“ VPC + PrivateLink +βœ“ Estimate: ~$300-800/mo (depending on users and 
volume) +``` + +--- + +## 6. Comparison with Alternatives + +| Criteria | This solution (AgentCore) | Amazon Q Business | Tableau Ask Data | Power BI Copilot | +|----------|--------------------------|-------------------|------------------|------------------| +| Setup | ~2 days | ~1-2 weeks | Requires Tableau | Requires Power BI | +| Customization | Full (Python code) | Limited | Limited | Limited | +| LLM Model | Claude Sonnet 4 (choice) | AWS proprietary | Salesforce proprietary | GPT-4 (Microsoft) | +| Base cost | ~$20/mo | ~$25/user/mo | ~$75/user/mo | ~$30/user/mo | +| Native Data Lake | Yes (Glue + Athena + S3) | Yes | No (needs connector) | No (needs connector) | +| Guardrails | Amazon Bedrock Guardrails | Basic | No | Basic | +| Conversational memory | STM + LTM (AgentCore) | Yes | No | Limited | + +--- + +## 7. How to Customize This Template + +### Step 1: Define your tables +Edit `config/tables.yaml` with your real data structure. + +### Step 2: Configure the prompt +Edit `config/system_prompt.yaml`: +- `business_dictionary`: your business terms +- `examples`: 10-15 relevant SQL queries +- `naming_conventions`: your tables and relationships + +### Step 3: Generate test data +```bash +python scripts/init_demo_data.py +aws s3 cp data/demo/ s3://YOUR-BUCKET/data/ --recursive +``` + +### Step 4: Deploy +```bash +cd cdk/ +pip install -r requirements.txt +cdk bootstrap aws://YOUR_ACCOUNT/us-east-1 +cdk deploy --all +``` + +### Step 5: Configure AgentCore +```bash +npm install -g @aws/agentcore-cli +agentcore init +agentcore deploy --region us-east-1 +``` diff --git a/02-use-cases/text-to-sql-data-analyst/docs/architecture.png b/02-use-cases/text-to-sql-data-analyst/docs/architecture.png new file mode 100644 index 000000000..8c5105ca4 Binary files /dev/null and b/02-use-cases/text-to-sql-data-analyst/docs/architecture.png differ diff --git a/02-use-cases/text-to-sql-data-analyst/frontend/index.html b/02-use-cases/text-to-sql-data-analyst/frontend/index.html new 
file mode 100644 index 000000000..a84a1e97a --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/frontend/index.html @@ -0,0 +1,240 @@ + + + + + + + + Text-to-SQL GenAI + + + + + +
+
+

Text-to-SQL GenAI

+

Query your data in natural language

+
+ +
+
+ + +
+ +
+

Example Queries

+
+ + + + + + +
+
+
+ + + + + +
+

Available Tables

+

Schema of the tables you can query

+ +
+
+
+ +

customers

+
+

Registered customers

+
+
BIGINT customer_id PK
+
STRING name
+
STRING email
+
STRING registration_date YYYY-MM-DD
+
STRING country
+
STRING segment premium, regular, new
+
+
+ +
+
+ +

products

+
+

Product catalog

+
+
INT product_id PK
+
STRING name
+
STRING category
+
DOUBLE price USD
+
INT stock Low < 50
+
STRING supplier
+
+
+ +
+
+ +

sales

+
+

Sales transactions

+
+
BIGINT sale_id PK
+
BIGINT customer_id FK
+
BIGINT product_id FK
+
STRING sale_date YYYY-MM-DD
+
INT quantity
+
DOUBLE total_amount USD
+
DOUBLE discount 0.0-1.0
+
STRING payment_method
+
+
+
+ +
+
+ +
Database: my_company_demo
+
+
+ +
Engine: Amazon Athena (Presto SQL)
+
+
+ +
Format: Parquet (columnar)
+
+
+
+
+ + + + diff --git a/02-use-cases/text-to-sql-data-analyst/frontend/static/styles.css b/02-use-cases/text-to-sql-data-analyst/frontend/static/styles.css new file mode 100644 index 000000000..b1c7d86e6 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/frontend/static/styles.css @@ -0,0 +1,257 @@ +/* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. */ +/* SPDX-License-Identifier: Apache-2.0 */ + +/* ============================================================================= + Text-to-SQL Sample β€” Styles + Customize colors and branding for your project + ============================================================================= */ + +:root { + --primary: #667eea; + --primary-dark: #5a67d8; + --bg: #f7f8fc; + --card-bg: #ffffff; + --text: #2d3748; + --text-light: #718096; + --border: #e2e8f0; + --success: #48bb78; + --error: #e53e3e; + --radius: 12px; +} + +* { margin: 0; padding: 0; box-sizing: border-box; } + +body { + font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; + background: var(--bg); + color: var(--text); + line-height: 1.6; +} + +.container { + max-width: 1100px; + margin: 0 auto; + padding: 30px 20px; +} + +.header { + text-align: center; + margin-bottom: 30px; +} +.header h1 { + font-size: 2rem; + font-weight: 700; + background: linear-gradient(135deg, var(--primary), #764ba2); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; +} +.subtitle { + color: var(--text-light); + margin-top: 5px; +} + +.query-section { + background: var(--card-bg); + border-radius: var(--radius); + padding: 25px; + box-shadow: 0 1px 3px rgba(0,0,0,0.1); + margin-bottom: 25px; +} +.input-group { + display: flex; + gap: 12px; + align-items: flex-start; +} +textarea { + flex: 1; + padding: 12px 16px; + border: 2px solid var(--border); + border-radius: 8px; + font-family: inherit; + font-size: 0.95rem; + resize: vertical; + transition: border-color 0.2s; +} +textarea:focus { + outline: none; + border-color: 
var(--primary); +} +.btn-primary { + padding: 12px 24px; + background: linear-gradient(135deg, var(--primary), #764ba2); + color: white; + border: none; + border-radius: 8px; + font-weight: 600; + cursor: pointer; + transition: transform 0.1s, opacity 0.2s; + white-space: nowrap; +} +.btn-primary:hover { transform: translateY(-1px); opacity: 0.95; } +.btn-primary:disabled { opacity: 0.6; cursor: not-allowed; transform: none; } + +.spinner { + display: inline-block; + width: 14px; height: 14px; + border: 2px solid rgba(255,255,255,0.3); + border-top-color: white; + border-radius: 50%; + animation: spin 0.6s linear infinite; +} +@keyframes spin { to { transform: rotate(360deg); } } + +.btn-loading { display: inline-flex; align-items: center; gap: 6px; } + +.examples-section { margin-top: 20px; } +.examples-section h3 { font-size: 0.9rem; color: var(--text-light); margin-bottom: 10px; } +.examples-grid { + display: flex; + flex-wrap: wrap; + gap: 8px; +} +.example-btn { + padding: 6px 14px; + background: #edf2f7; + border: 1px solid var(--border); + border-radius: 20px; + font-size: 0.82rem; + cursor: pointer; + transition: background 0.2s; +} +.example-btn:hover { background: #e2e8f0; } + +.results-section { + background: var(--card-bg); + border-radius: var(--radius); + padding: 25px; + box-shadow: 0 1px 3px rgba(0,0,0,0.1); + margin-bottom: 25px; +} +.sql-section { margin-bottom: 20px; } +.sql-section h3, .table-section h3 { font-size: 1rem; margin-bottom: 10px; } +.sql-code { + background: #1a202c; + color: #e2e8f0; + padding: 16px; + border-radius: 8px; + overflow-x: auto; + font-family: 'Fira Code', 'Consolas', monospace; + font-size: 0.85rem; + line-height: 1.5; +} +.results-info { + font-size: 0.85rem; + color: var(--text-light); + margin-bottom: 10px; +} +.table-container { overflow-x: auto; } +.results-table { + width: 100%; + border-collapse: collapse; + font-size: 0.85rem; +} +.results-table th { + background: #f7fafc; + padding: 10px 12px; + text-align: 
left; + font-weight: 600; + border-bottom: 2px solid var(--border); + white-space: nowrap; +} +.results-table td { + padding: 8px 12px; + border-bottom: 1px solid var(--border); +} +.results-table tr:hover { background: #f7fafc; } + +.error-section { + background: #fff5f5; + border: 1px solid #fed7d7; + border-radius: var(--radius); + padding: 20px; + margin-bottom: 25px; +} +.error-section h3 { color: var(--error); } +.error-section p { color: #c53030; margin-top: 8px; } + +.database-tables-section { + background: var(--card-bg); + border-radius: var(--radius); + padding: 25px; + box-shadow: 0 1px 3px rgba(0,0,0,0.1); +} +.section-description { color: var(--text-light); font-size: 0.9rem; margin-bottom: 20px; } +.tables-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); + gap: 16px; + margin-bottom: 20px; +} +.table-card { + border: 1px solid var(--border); + border-radius: 10px; + padding: 16px; + transition: box-shadow 0.2s; +} +.table-card:hover { box-shadow: 0 4px 12px rgba(0,0,0,0.08); } +.table-card-header { + display: flex; + align-items: center; + gap: 10px; + margin-bottom: 8px; +} +.table-icon { font-size: 1.5rem; } +.table-card-header h4 { + font-family: 'Fira Code', monospace; + color: var(--primary); +} +.table-description { font-size: 0.85rem; color: var(--text-light); margin-bottom: 12px; } +.table-columns-list { display: flex; flex-direction: column; gap: 4px; } +.column-item { + font-size: 0.82rem; + padding: 4px 8px; + background: #f7fafc; + border-radius: 4px; + font-family: 'Fira Code', monospace; +} +.col-type { + display: inline-block; + background: var(--primary); + color: white; + padding: 1px 6px; + border-radius: 3px; + font-size: 0.7rem; + font-weight: 600; + margin-right: 4px; +} +.col-comment { + color: var(--text-light); + font-size: 0.75rem; + font-family: 'Inter', sans-serif; + margin-left: 4px; +} + +.database-info { + display: flex; + flex-wrap: wrap; + gap: 12px; + margin-top: 16px; + 
padding-top: 16px; + border-top: 1px solid var(--border); +} +.info-box { + display: flex; + align-items: center; + gap: 8px; + padding: 8px 14px; + background: #f7fafc; + border-radius: 8px; + font-size: 0.85rem; +} +.info-icon { font-size: 1.2rem; } + +@media (max-width: 768px) { + .input-group { flex-direction: column; } + .tables-grid { grid-template-columns: 1fr; } + .database-info { flex-direction: column; } +} diff --git a/02-use-cases/text-to-sql-data-analyst/pyproject.toml b/02-use-cases/text-to-sql-data-analyst/pyproject.toml new file mode 100644 index 000000000..d0c776ac4 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/pyproject.toml @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" + +[tool.black] +line-length = 100 +target-version = ["py311"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] diff --git a/02-use-cases/text-to-sql-data-analyst/requirements.txt b/02-use-cases/text-to-sql-data-analyst/requirements.txt new file mode 100644 index 000000000..78b15642d --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/requirements.txt @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
#!/usr/bin/env python3
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Sample Data Generator

Reads config/tables.yaml and generates sample data in Parquet format.
Customize the generate_* functions for your business domain.

Usage:
    python scripts/init_demo_data.py
    aws s3 cp data/demo/ s3://YOUR-BUCKET/data/ --recursive
"""

import os
import random
from datetime import datetime, timedelta
from pathlib import Path

import pandas as pd

# Configuration
OUTPUT_DIR = "data/demo"
NUM_CUSTOMERS = 1000
NUM_PRODUCTS = 200
NUM_SALES = 5000


def load_tables_config():
    """Load the table configuration from config/tables.yaml.

    Returns:
        The parsed YAML document (typically a dict).
    """
    # Imported lazily: PyYAML is only needed here, so the generate_*
    # functions stay importable (and testable) without it.
    import yaml

    config_path = Path(__file__).parent.parent / "config" / "tables.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def generate_customers(num_records: int) -> pd.DataFrame:
    """Generate sample customer data. Customize for your business.

    Args:
        num_records: Exact number of customer rows to generate.

    Returns:
        DataFrame with columns customer_id, name, email,
        registration_date, country, segment.
    """
    print(f"Generating {num_records} customers...")

    first_names = [
        "Alice", "Bob", "Carol", "David", "Emma", "Frank", "Grace", "Henry",
        "Iris", "Jack", "Karen", "Leo", "Maria", "Noah", "Olivia", "Paul",
    ]
    last_names = [
        "Smith", "Johnson", "Williams", "Brown", "Jones", "Garcia",
        "Miller", "Davis", "Rodriguez", "Martinez", "Wilson", "Anderson",
    ]
    countries = ["United States", "Canada", "United Kingdom", "Germany", "France", "Spain"]
    # Segment -> sampling weight (weights sum to 1.0).
    segments = {"premium": 0.15, "regular": 0.60, "new": 0.25}

    data = []
    # Registrations spread over the last two years.
    start_date = datetime.now() - timedelta(days=730)

    for i in range(1, num_records + 1):
        first = random.choice(first_names)
        last = random.choice(last_names)
        segment = random.choices(
            list(segments.keys()), weights=list(segments.values())
        )[0]
        reg_date = start_date + timedelta(days=random.randint(0, 730))

        data.append({
            "customer_id": i,
            "name": f"{first} {last}",
            # Suffix with the id so emails are unique even when names repeat.
            "email": f"{first.lower()}.{last.lower()}{i}@example.com",
            "registration_date": reg_date.strftime("%Y-%m-%d"),
            "country": random.choice(countries),
            "segment": segment,
        })

    return pd.DataFrame(data)


def generate_products(num_records: int) -> pd.DataFrame:
    """Generate sample product data. Customize for your business.

    Always returns exactly ``num_records`` rows: the remainder of
    ``num_records // len(categories)`` is distributed over the first
    categories. (A plain floor division would silently drop up to
    ``len(categories) - 1`` rows whenever num_records is not a multiple
    of the category count.)

    Args:
        num_records: Exact number of product rows to generate.

    Returns:
        DataFrame with columns product_id, name, category, price,
        stock, supplier.
    """
    print(f"Generating {num_records} products...")

    # Category -> price range used for random pricing.
    categories = {
        "electronics": {"min": 50, "max": 1500},
        "clothing": {"min": 15, "max": 200},
        "home": {"min": 10, "max": 500},
        "sports": {"min": 20, "max": 800},
        "food": {"min": 2, "max": 100},
    }
    suppliers = [
        "Supplier A", "Supplier B", "Supplier C", "Supplier D", "Supplier E",
    ]

    data = []
    pid = 1
    per_cat, remainder = divmod(num_records, len(categories))

    for idx, (cat, config) in enumerate(categories.items()):
        # The first `remainder` categories each absorb one extra product.
        count = per_cat + (1 if idx < remainder else 0)
        for _ in range(count):
            data.append({
                "product_id": pid,
                "name": f"{cat.title()} Product {pid}",
                "category": cat,
                "price": round(random.uniform(config["min"], config["max"]), 2),
                "stock": random.randint(0, 500),
                "supplier": random.choice(suppliers),
            })
            pid += 1

    return pd.DataFrame(data)


def generate_sales(
    num_records: int, num_customers: int, num_products: int
) -> pd.DataFrame:
    """Generate sample sales data.

    Args:
        num_records: Exact number of sale rows to generate.
        num_customers: Upper bound (inclusive) for random customer_id values.
        num_products: Upper bound (inclusive) for random product_id values.

    Returns:
        DataFrame with columns sale_id, customer_id, product_id, sale_date,
        quantity, total_amount, discount, payment_method.
    """
    print(f"Generating {num_records} sales...")

    # Payment method -> sampling weight (weights sum to 1.0).
    payment_methods = {"credit_card": 0.60, "transfer": 0.30, "cash": 0.10}
    data = []
    # Sales spread over the last year.
    start_date = datetime.now() - timedelta(days=365)

    for i in range(1, num_records + 1):
        sale_date = start_date + timedelta(days=random.randint(0, 365))
        quantity = random.randint(1, 10)
        base_price = random.uniform(20, 500)
        discount = round(random.uniform(0, 0.20), 2)
        total = round(base_price * quantity * (1 - discount), 2)
        method = random.choices(
            list(payment_methods.keys()),
            weights=list(payment_methods.values()),
        )[0]

        data.append({
            "sale_id": i,
            "customer_id": random.randint(1, num_customers),
            "product_id": random.randint(1, num_products),
            "sale_date": sale_date.strftime("%Y-%m-%d"),
            "quantity": quantity,
            "total_amount": total,
            "discount": discount,
            "payment_method": method,
        })

    return pd.DataFrame(data)


def save_parquet(df: pd.DataFrame, table_name: str):
    """Write *df* to OUTPUT_DIR/<table_name>.parquet (pyarrow engine)."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    path = os.path.join(OUTPUT_DIR, f"{table_name}.parquet")
    df.to_parquet(path, engine="pyarrow", index=False)
    print(f"  -> {path} ({len(df)} records)")


def main():
    """Generate all demo tables and print the S3 upload commands."""
    config = load_tables_config()
    db_name = config.get("database_name", "demo")
    print(f"\nGenerating data for: {db_name}\n")

    customers = generate_customers(NUM_CUSTOMERS)
    products = generate_products(NUM_PRODUCTS)
    sales = generate_sales(NUM_SALES, NUM_CUSTOMERS, NUM_PRODUCTS)

    save_parquet(customers, "customers")
    save_parquet(products, "products")
    save_parquet(sales, "sales")

    bucket = os.environ.get("DEMO_S3_BUCKET", "my-company-text-to-sql-data")
    prefix = config.get("s3_data_prefix", "data")

    print(f"\nData generated in {OUTPUT_DIR}/")
    # Plain string: the original used an f-string with no placeholders.
    print("\nNext step β€” upload to S3:")
    print(f"  aws s3 cp {OUTPUT_DIR}/customers.parquet s3://{bucket}/{prefix}/customers/customers.parquet")
    print(f"  aws s3 cp {OUTPUT_DIR}/products.parquet s3://{bucket}/{prefix}/products/products.parquet")
    print(f"  aws s3 cp {OUTPUT_DIR}/sales.parquet s3://{bucket}/{prefix}/sales/sales.parquet\n")


if __name__ == "__main__":
    main()
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Policy Validator β€” SQL Security Validation

Ensures only SELECT queries are executed.
Validates against dangerous commands and auto-applies LIMIT.
"""

import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PolicyValidationResult:
    """SQL validation result."""

    valid: bool
    # Human-readable rejection reason; None when valid.
    reason: Optional[str] = None
    # SQL with a safety LIMIT appended when one was missing; None when invalid.
    modified_sql: Optional[str] = None


class PolicyValidator:
    """
    Validates SQL queries against security policies.

    Ensures that:
    - Only SELECT commands are allowed
    - No dangerous DDL/DML keywords are present
    - LIMIT is added if missing
    """

    DANGEROUS_KEYWORDS = [
        "DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
        "CREATE", "TRUNCATE", "GRANT", "REVOKE", "EXEC", "EXECUTE",
    ]

    # Word-boundary matching: a plain substring scan would reject harmless
    # identifiers such as "created_at" (contains CREATE) or "last_update"
    # (contains UPDATE). Applied to the upper-cased SQL text.
    _KEYWORD_PATTERN = re.compile(r"\b(" + "|".join(DANGEROUS_KEYWORDS) + r")\b")
    # Case-insensitive whole-word LIMIT, so names like "limited_offers"
    # do not suppress the safety LIMIT.
    _LIMIT_PATTERN = re.compile(r"\bLIMIT\b", re.IGNORECASE)

    def __init__(self, default_limit: int = 1000):
        # Row cap appended to queries that lack an explicit LIMIT.
        self.default_limit = default_limit

    def validate(self, sql: str) -> PolicyValidationResult:
        """Validate a SQL query against security policies.

        Args:
            sql: Raw SQL text to validate.

        Returns:
            PolicyValidationResult with valid=True and modified_sql set,
            or valid=False and a rejection reason.
        """
        if not sql or not sql.strip():
            return PolicyValidationResult(valid=False, reason="SQL query is empty")

        # Imported lazily so the regex-based helpers on this class can be
        # used (and unit tested) without the third-party sqlparse package.
        import sqlparse

        try:
            parsed = sqlparse.parse(sql)
        except Exception as e:
            return PolicyValidationResult(
                valid=False, reason=f"SQL could not be parsed: {e}"
            )

        if not parsed:
            return PolicyValidationResult(
                valid=False, reason="SQL could not be parsed"
            )

        for statement in parsed:
            if not self._is_select_statement(statement):
                stmt_type = statement.get_type()
                return PolicyValidationResult(
                    valid=False,
                    reason=f"Command {stmt_type} not allowed. Only SELECT is valid.",
                )

        match = self._KEYWORD_PATTERN.search(sql.upper())
        if match:
            return PolicyValidationResult(
                valid=False,
                reason=f"Dangerous keyword detected: {match.group(1)}",
            )

        modified_sql = self._ensure_limit(sql)
        return PolicyValidationResult(valid=True, modified_sql=modified_sql)

    def _is_select_statement(self, statement) -> bool:
        """Return True when *statement* (a sqlparse Statement) is a SELECT."""
        return statement.get_type() == "SELECT"

    def _ensure_limit(self, sql: str) -> str:
        """Return *sql* unchanged if it already has a LIMIT, else append one."""
        if self._LIMIT_PATTERN.search(sql):
            return sql
        sql_stripped = sql.rstrip().rstrip(";")
        return f"{sql_stripped} LIMIT {self.default_limit}"
+""" + +from dataclasses import dataclass +from typing import List, Optional + +import boto3 +from botocore.exceptions import ClientError + + +@dataclass +class ColumnInfo: + name: str + type: str + description: str + + +@dataclass +class TableInfo: + name: str + description: str + columns: List[ColumnInfo] + relationships: Optional[List[str]] = None + + +@dataclass +class SchemaInfo: + tables: List[TableInfo] + total_tables_in_catalog: int = 0 + + +class SchemaDiscoveryError(Exception): + pass + + +def discover_schema( + keywords: List[str], + database_name: str, + aws_region: str = "us-east-1", + max_tables: int = 5, +) -> SchemaInfo: + """ + Discover relevant tables in Glue Data Catalog based on keywords. + + Args: + keywords: List of keywords (e.g., ["sales", "customer"]) + database_name: Glue database name + aws_region: AWS region + max_tables: Maximum tables to return + """ + if not keywords: + raise SchemaDiscoveryError("Keywords list cannot be empty") + if not database_name: + raise SchemaDiscoveryError("Database name cannot be empty") + + try: + glue_client = boto3.client("glue", region_name=aws_region) + except Exception as e: + raise SchemaDiscoveryError(f"Failed to create Glue client: {e}") + + try: + relevant_tables = _search_tables_by_keywords( + glue_client, database_name, keywords + ) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + if error_code == "EntityNotFoundException": + raise SchemaDiscoveryError(f"Database '{database_name}' not found") + raise SchemaDiscoveryError(f"Failed to search tables: {e}") + + relevant_tables = relevant_tables[:max_tables] + schema_info = SchemaInfo(tables=[]) + + for table_name in relevant_tables: + try: + table_metadata = glue_client.get_table( + DatabaseName=database_name, Name=table_name + ) + table = table_metadata["Table"] + columns = [ + ColumnInfo( + name=col["Name"], + type=col["Type"], + description=col.get("Comment", ""), + ) + for col in 
table.get("StorageDescriptor", {}).get("Columns", []) + ] + relationships = _extract_relationships(table.get("Description", "")) + schema_info.tables.append( + TableInfo( + name=table_name, + description=table.get("Description", ""), + columns=columns, + relationships=relationships if relationships else None, + ) + ) + except ClientError: + continue + + try: + all_tables = glue_client.get_tables(DatabaseName=database_name) + schema_info.total_tables_in_catalog = len(all_tables.get("TableList", [])) + except ClientError: + pass + + return schema_info + + +def _search_tables_by_keywords( + glue_client, database_name: str, keywords: List[str] +) -> List[str]: + all_tables = [] + next_token = None + while True: + kwargs = {"DatabaseName": database_name} + if next_token: + kwargs["NextToken"] = next_token + response = glue_client.get_tables(**kwargs) + all_tables.extend(response.get("TableList", [])) + next_token = response.get("NextToken") + if not next_token: + break + + keywords_lower = [kw.lower() for kw in keywords] + scored_tables = [] + + for table in all_tables: + table_name = table["Name"] + table_desc = table.get("Description", "") + score = 0 + + for kw in keywords_lower: + if kw in table_name.lower(): + score += 10 + if kw in table_desc.lower(): + score += 5 + for col in table.get("StorageDescriptor", {}).get("Columns", []): + if kw in col["Name"].lower(): + score += 3 + if kw in col.get("Comment", "").lower(): + score += 1 + + if score > 0: + scored_tables.append((table_name, score)) + + scored_tables.sort(key=lambda x: x[1], reverse=True) + return [name for name, _ in scored_tables] + + +def _extract_relationships(description: str) -> List[str]: + if not description: + return [] + relationships = [] + indicators = ["join", "related", "relationship", "foreign key", "references", "links"] + for indicator in indicators: + if indicator in description.lower(): + for sentence in description.split("."): + if indicator in sentence.lower(): + 
relationships.append(sentence.strip()) + return relationships diff --git a/02-use-cases/text-to-sql-data-analyst/src/tools/execute_query.py b/02-use-cases/text-to-sql-data-analyst/src/tools/execute_query.py new file mode 100644 index 000000000..da31def65 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/src/tools/execute_query.py @@ -0,0 +1,325 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Query Execution Tool + +Executes SQL queries on Amazon Athena or Amazon Redshift. +Applies security policies: timeouts and row limits. +""" + +from dataclasses import dataclass +from typing import List, Dict, Any, Literal +import time + +import boto3 +from botocore.exceptions import ClientError + + +@dataclass +class ColumnMetadata: + name: str + type: str + nullable: bool = True + + +@dataclass +class QueryResult: + columns: List[ColumnMetadata] + rows: List[Dict[str, Any]] + row_count: int + truncated: bool + execution_time: int # milliseconds + query_id: str = "" + data_scanned_bytes: int = 0 + + +class QueryExecutionError(Exception): + def __init__(self, message: str, user_message: str = None): + super().__init__(message) + self.user_message = user_message or message + + +class PolicyViolationError(Exception): + def __init__(self, reason: str): + super().__init__(reason) + self.reason = reason + + +def execute_query( + sql: str, + engine_type: Literal["athena", "redshift"], + database_name: str, + aws_region: str = "us-east-1", + timeout_seconds: int = 30, + max_rows: int = 1000, + athena_output_location: str = None, + athena_workgroup: str = "primary", + redshift_cluster_id: str = None, + redshift_db_user: str = None, +) -> QueryResult: + """Execute a SQL query on Athena or Redshift with security policies.""" + if not sql or not sql.strip(): + raise QueryExecutionError("SQL query is empty", "The SQL query is empty") + if not database_name: + raise QueryExecutionError("Database name is 
required") + if not _is_select_only(sql): + raise PolicyViolationError("Only SELECT commands are allowed") + + start_time = time.time() + + if engine_type == "athena": + if not athena_output_location: + raise ValueError("athena_output_location is required for Athena queries") + result = _execute_athena_query( + sql, database_name, athena_output_location, + athena_workgroup, aws_region, timeout_seconds, + ) + elif engine_type == "redshift": + if not redshift_cluster_id or not redshift_db_user: + raise ValueError( + "redshift_cluster_id and redshift_db_user are required" + ) + result = _execute_redshift_query( + sql, database_name, redshift_cluster_id, + redshift_db_user, aws_region, timeout_seconds, + ) + else: + raise ValueError(f"Unsupported engine type: {engine_type}") + + result.execution_time = int((time.time() - start_time) * 1000) + + if len(result.rows) > max_rows: + result.rows = result.rows[:max_rows] + result.row_count = max_rows + result.truncated = True + + return result + + +def _is_select_only(sql: str) -> bool: + sql_upper = sql.upper().strip() + if not (sql_upper.startswith("SELECT") or sql_upper.startswith("WITH")): + return False + dangerous = [ + "DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE", + ] + return not any(kw in sql_upper for kw in dangerous) + + +def _execute_athena_query( + sql, database_name, output_location, workgroup, aws_region, timeout_seconds +): + try: + athena = boto3.client("athena", region_name=aws_region) + except Exception as e: + raise QueryExecutionError( + f"Failed to create Athena client: {e}", + "Could not connect to Athena", + ) + + try: + response = athena.start_query_execution( + QueryString=sql, + QueryExecutionContext={"Database": database_name}, + ResultConfiguration={"OutputLocation": output_location}, + WorkGroup=workgroup, + ) + qid = response["QueryExecutionId"] + except ClientError as e: + msg = e.response.get("Error", {}).get("Message", str(e)) + raise QueryExecutionError( + f"Failed to 
start query: {msg}", + f"Error starting query: {msg}", + ) + + status = _wait_for_athena(athena, qid, timeout_seconds) + + if status == "FAILED": + try: + ex = athena.get_query_execution(QueryExecutionId=qid) + reason = ex["QueryExecution"]["Status"].get( + "StateChangeReason", "Unknown" + ) + except Exception: + reason = "Unknown" + raise QueryExecutionError( + f"Query failed: {reason}", f"Query failed: {reason}" + ) + elif status == "TIMEOUT": + raise QueryExecutionError( + "Query timeout", f"Timeout of {timeout_seconds}s exceeded" + ) + elif status != "SUCCEEDED": + raise QueryExecutionError(f"Unexpected status: {status}") + + data_scanned = 0 + try: + ex = athena.get_query_execution(QueryExecutionId=qid) + data_scanned = ( + ex.get("QueryExecution", {}) + .get("Statistics", {}) + .get("DataScannedInBytes", 0) + ) + except Exception: + pass + + try: + results = athena.get_query_results(QueryExecutionId=qid) + result = _parse_athena_results(results, qid) + result.data_scanned_bytes = data_scanned + return result + except ClientError as e: + msg = e.response.get("Error", {}).get("Message", str(e)) + raise QueryExecutionError( + f"Failed to get results: {msg}", + f"Error retrieving results: {msg}", + ) + + +def _wait_for_athena(athena, qid, timeout_seconds, poll=0.5): + start = time.time() + while True: + if time.time() - start > timeout_seconds: + try: + athena.stop_query_execution(QueryExecutionId=qid) + except Exception: + pass + return "TIMEOUT" + try: + resp = athena.get_query_execution(QueryExecutionId=qid) + status = resp["QueryExecution"]["Status"]["State"] + if status in ("SUCCEEDED", "FAILED", "CANCELLED"): + return status + except ClientError: + pass + time.sleep(poll) + + +def _parse_athena_results(results, query_id): + result_set = results.get("ResultSet", {}) + rows_data = result_set.get("Rows", []) + if not rows_data: + return QueryResult( + columns=[], rows=[], row_count=0, truncated=False, + execution_time=0, query_id=query_id, + ) + + column_info 
= result_set.get("ResultSetMetadata", {}).get("ColumnInfo", []) + columns = [ + ColumnMetadata( + name=c.get("Name", ""), + type=c.get("Type", "VARCHAR"), + nullable=c.get("Nullable", "NULLABLE") != "NOT_NULL", + ) + for c in column_info + ] + + rows = [] + for row_data in rows_data[1:]: + row_dict = {} + for i, col in enumerate(columns): + cell = ( + row_data.get("Data", [])[i] + if i < len(row_data.get("Data", [])) + else {} + ) + value = cell.get("VarCharValue") + row_dict[col.name] = ( + _convert_value(value, col.type) if value is not None else None + ) + rows.append(row_dict) + + return QueryResult( + columns=columns, rows=rows, row_count=len(rows), truncated=False, + execution_time=0, query_id=query_id, + ) + + +def _execute_redshift_query( + sql, database_name, cluster_id, db_user, aws_region, timeout_seconds +): + try: + client = boto3.client("redshift-data", region_name=aws_region) + except Exception as e: + raise QueryExecutionError(f"Failed to create Redshift client: {e}") + + try: + response = client.execute_statement( + ClusterIdentifier=cluster_id, + Database=database_name, + DbUser=db_user, + Sql=sql, + ) + sid = response["Id"] + except ClientError as e: + msg = e.response.get("Error", {}).get("Message", str(e)) + raise QueryExecutionError(f"Redshift query failed: {msg}") + + start = time.time() + while True: + if time.time() - start > timeout_seconds: + try: + client.cancel_statement(Id=sid) + except Exception: + pass + raise QueryExecutionError("Redshift query timeout") + try: + desc = client.describe_statement(Id=sid) + status = desc["Status"] + if status == "FINISHED": + break + elif status in ("FAILED", "ABORTED"): + raise QueryExecutionError( + f"Redshift query {status}: {desc.get('Error', 'Unknown')}" + ) + except ClientError: + pass + time.sleep(0.5) + + results = client.get_statement_result(Id=sid) + col_meta = results.get("ColumnMetadata", []) + columns = [ + ColumnMetadata(name=c.get("name", ""), type=c.get("typeName", "VARCHAR")) + for c 
in col_meta + ] + + rows = [] + for record in results.get("Records", []): + row = {} + for i, col in enumerate(columns): + if i < len(record): + field = record[i] + value = ( + field.get("stringValue") + or field.get("longValue") + or field.get("doubleValue") + or field.get("booleanValue") + ) + if field.get("isNull"): + value = None + row[col.name] = value + else: + row[col.name] = None + rows.append(row) + + return QueryResult( + columns=columns, rows=rows, row_count=len(rows), truncated=False, + execution_time=0, query_id=sid, + ) + + +def _convert_value(value, col_type): + if value is None or value == "": + return None + ct = col_type.upper() + try: + if ct in ("INTEGER", "INT", "BIGINT", "SMALLINT", "TINYINT"): + return int(value) + elif ct in ("DOUBLE", "FLOAT", "DECIMAL", "REAL"): + return float(value) + elif ct in ("BOOLEAN", "BOOL"): + return value.lower() in ("true", "1", "t", "yes") + except (ValueError, AttributeError): + pass + return value diff --git a/02-use-cases/text-to-sql-data-analyst/tests/__init__.py b/02-use-cases/text-to-sql-data-analyst/tests/__init__.py new file mode 100644 index 000000000..04f8b7b76 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/02-use-cases/text-to-sql-data-analyst/tests/test_policy_validator.py b/02-use-cases/text-to-sql-data-analyst/tests/test_policy_validator.py new file mode 100644 index 000000000..e299c5792 --- /dev/null +++ b/02-use-cases/text-to-sql-data-analyst/tests/test_policy_validator.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0

"""Tests for PolicyValidator."""

import pytest
from src.policy_validator import PolicyValidator, PolicyValidationResult


@pytest.fixture
def validator():
    """A validator with the standard 1000-row safety limit."""
    return PolicyValidator(default_limit=1000)


class TestPolicyValidator:
    """Test SQL policy validation."""

    # --- queries that must be accepted -----------------------------------

    def test_valid_select(self, validator):
        """A bare SELECT passes."""
        verdict = validator.validate("SELECT * FROM customers")
        assert verdict.valid is True

    def test_valid_select_with_where(self, validator):
        """A filtered SELECT passes."""
        verdict = validator.validate(
            "SELECT name, email FROM customers WHERE segment = 'premium'"
        )
        assert verdict.valid is True

    def test_valid_select_with_join(self, validator):
        """A multi-table aggregate SELECT passes."""
        sql = """
        SELECT c.name, SUM(s.total_amount) as total
        FROM sales s
        JOIN customers c ON s.customer_id = c.customer_id
        GROUP BY c.name
        ORDER BY total DESC
        """
        verdict = validator.validate(sql)
        assert verdict.valid is True

    def test_valid_with_cte(self, validator):
        """A WITH ... SELECT (CTE) query passes."""
        sql = """
        WITH top_customers AS (
            SELECT customer_id, SUM(total_amount) as total
            FROM sales GROUP BY customer_id
        )
        SELECT c.name, t.total
        FROM top_customers t
        JOIN customers c ON t.customer_id = c.customer_id
        """
        verdict = validator.validate(sql)
        assert verdict.valid is True

    # --- queries that must be rejected -----------------------------------

    def test_rejects_drop(self, validator):
        """DROP is rejected and the reason names the command."""
        verdict = validator.validate("DROP TABLE customers")
        assert verdict.valid is False
        assert "DROP" in verdict.reason

    def test_rejects_delete(self, validator):
        verdict = validator.validate("DELETE FROM customers WHERE id = 1")
        assert verdict.valid is False

    def test_rejects_insert(self, validator):
        verdict = validator.validate(
            "INSERT INTO customers (name) VALUES ('test')"
        )
        assert verdict.valid is False

    def test_rejects_update(self, validator):
        verdict = validator.validate(
            "UPDATE customers SET name = 'test' WHERE id = 1"
        )
        assert verdict.valid is False

    def test_rejects_truncate(self, validator):
        verdict = validator.validate("TRUNCATE TABLE customers")
        assert verdict.valid is False

    def test_rejects_create(self, validator):
        verdict = validator.validate("CREATE TABLE test (id INT)")
        assert verdict.valid is False

    def test_rejects_empty(self, validator):
        """An empty string is invalid input."""
        verdict = validator.validate("")
        assert verdict.valid is False

    def test_rejects_none(self, validator):
        """None is invalid input."""
        verdict = validator.validate(None)
        assert verdict.valid is False

    # --- LIMIT enforcement ------------------------------------------------

    def test_adds_limit_when_missing(self, validator):
        """The default safety LIMIT is appended when absent."""
        verdict = validator.validate("SELECT * FROM customers")
        assert verdict.valid is True
        assert "LIMIT 1000" in verdict.modified_sql

    def test_preserves_existing_limit(self, validator):
        """An explicit LIMIT is left untouched."""
        verdict = validator.validate("SELECT * FROM customers LIMIT 10")
        assert verdict.valid is True
        assert verdict.modified_sql == "SELECT * FROM customers LIMIT 10"

    def test_custom_default_limit(self):
        """A constructor-supplied limit is the one appended."""
        strict = PolicyValidator(default_limit=500)
        verdict = strict.validate("SELECT * FROM customers")
        assert "LIMIT 500" in verdict.modified_sql