diff --git a/.env.local.example b/.env.local.example index ad8b38c..20aa8e3 100644 --- a/.env.local.example +++ b/.env.local.example @@ -1,12 +1,35 @@ # Databricks Configuration +# Get these values from your Databricks workspace settings + +# Your Databricks workspace hostname (without https://) DATABRICKS_SERVER_HOSTNAME=your-workspace.cloud.databricks.com + +# Your Databricks Personal Access Token +# Generate from: User Settings > Developer > Access tokens +DATABRICKS_PAT_TOKEN=dapi123...your-databricks-pat-token + +# Your SQL Warehouse HTTP Path +# Get from: SQL Warehouses > Select warehouse > Connection details DATABRICKS_HTTP_PATH=/sql/1.0/warehouses/your-warehouse-id -DATABRICKS_PAT_TOKEN=dapi123...your-pat-token -DATABRICKS_METASTORE_REGION=us-west-1 # Skyflow Configuration +# Get these values from your Skyflow vault + +# Your Skyflow vault URL SKYFLOW_VAULT_URL=https://your-vault.vault.skyflowapis.com + +# Your Skyflow Personal Access Token +# Generate from: Skyflow Studio > Settings > Tokens +SKYFLOW_PAT_TOKEN=eyJhbGci...your-skyflow-pat-token + +# Your Skyflow vault ID SKYFLOW_VAULT_ID=your-vault-id -SKYFLOW_PAT_TOKEN=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...your-pat-token -SKYFLOW_TABLE=pii -SKYFLOW_BATCH_SIZE=25 \ No newline at end of file + +# Your Skyflow table name (where PII will be stored) +SKYFLOW_TABLE=customer_data + +# Optional: Group Mappings for Role-Based Access Control +# These control which user groups can see real vs tokenized data +PLAIN_TEXT_GROUPS=auditor # Groups that see real data +MASKED_GROUPS=customer_service # Groups that see masked data +REDACTED_GROUPS=marketing # Groups that see redacted data diff --git a/.gitignore b/.gitignore index 5cc2148..48d1838 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,47 @@ -# Environment files with credentials +# Environment files .env.local +.env -# OS generated files +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +venv/ +env/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS .DS_Store .DS_Store? ._* @@ -10,17 +50,24 @@ ehthumbs.db Thumbs.db -# IDE files -.vscode/ -.idea/ -*.swp -*.swo -*~ - # Logs *.log logs/ # Temporary files -tmp/ -temp/ \ No newline at end of file +*.tmp +*.temp +temp/ + +# Databricks +.databricks/ +databricks.yml + +# Test outputs +test_results/ +coverage/ + +# Runtime files +*.pid +*.seed +*.pid.lock \ No newline at end of file diff --git a/README.md b/README.md index 5a4eca0..2ad38a5 100644 --- a/README.md +++ b/README.md @@ -2,38 +2,22 @@ This solution provides secure data tokenization and detokenization capabilities in Databricks Unity Catalog to protect PII and other sensitive data using Skyflow's Data Privacy Vault services. Built with pure SQL UDFs using Unity Catalog HTTP connections for maximum performance and seamless integration with column-level security. 
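How the "pure SQL" piece fits together, at a glance: setup generates UDFs that call Skyflow directly through the `skyflow_conn` Unity Catalog HTTP connection using Databricks' `http_request()` SQL function, so no Python runs at query time. The sketch below is illustrative only, not the generated definition; it assumes the connection's base URL ends at the vault (`/v1/vaults/{vault_id}`) and uses Skyflow's detokenize request/response shape (`detokenizationParameters` / `records[0].value`), which you should confirm against your vault's API version.

```sql
-- Illustrative sketch of a pure SQL detokenization UDF over the UC HTTP connection.
-- The function actually deployed by setup ({prefix}_skyflow_uc_detokenize) is generated
-- from templates and includes error handling that falls back to returning the token.
CREATE OR REPLACE FUNCTION demo_catalog.default.demo_skyflow_uc_detokenize(token STRING)
RETURNS STRING
RETURN get_json_object(
  http_request(
    conn    => 'skyflow_conn',                 -- UC connection with bearer-token auth
    method  => 'POST',
    path    => 'detokenize',                   -- relative to the connection's /v1/vaults/{vault_id} base (assumed)
    headers => map('Content-Type', 'application/json', 'Accept', 'application/json'),
    json    => to_json(named_struct(
                 'detokenizationParameters', array(named_struct('token', token))))
  ).text,
  '$.records[0].value'                         -- assumed response shape; adjust to your vault
);
```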
-## Table of Contents - -- [Quick Start](#quick-start) -- [Key Benefits](#key-benefits) -- [Architecture](#architecture) -- [Flow Diagrams](#flow-diagrams) - - [Tokenization Flow](#tokenization-flow) - - [Detokenization Flow](#detokenization-flow) -- [Features](#features) -- [Prerequisites](#prerequisites) -- [Usage Examples](#usage-examples) -- [Project Structure](#project-structure) -- [Configuration](#configuration) -- [Development Guide](#development-guide) -- [Cleanup](#cleanup) -- [Dashboard Integration](#dashboard-integration) -- [Support](#support) -- [License](#license) +## Demo + +A demonstration of this solution was featured in the 'From PII to GenAI: Architecting for Data Privacy & Security in 2025' webinar. + +[![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/x2-wVW04njw/0.jpg)](https://www.youtube.com/watch?v=x2-wVW04njw&t=623s) ## Quick Start 1. **Clone and Configure**: - ```bash git clone cd databricks-skyflow-integration cp .env.local.example .env.local ``` -2. **Set Environment Variables**: - - Edit `.env.local` with your credentials: +2. **Set Environment Variables** - Edit `.env.local` with your credentials: ```bash # Databricks Configuration DATABRICKS_SERVER_HOSTNAME=your-workspace.cloud.databricks.com @@ -46,87 +30,68 @@ This solution provides secure data tokenization and detokenization capabilities SKYFLOW_VAULT_ID=your-vault-id ``` -3. **Deploy Everything**: - +3. **Install Dependencies & Deploy**: ```bash - ./setup.sh create demo + pip install -r requirements.txt + python setup.py create demo ``` This creates: - - ✅ Unity Catalog: `demo_catalog` - - ✅ Sample table with 25,000 tokenized records + - ✅ Unity Catalog: `demo_catalog` + - ✅ Sample table with tokenized records - ✅ UC connections and secrets - ✅ Pure SQL detokenization functions - - ✅ Column masks (first_name only) + - ✅ Column masks on 6 PII columns with role-based access - ✅ Customer insights dashboard -4. **Test Access**: - +4. **Test Role-Based Access**: ```sql - -- Auditor group members see real names - -- Others see tokens - SELECT first_name FROM demo_catalog.default.demo_customer_data; + -- Query returns different results based on your role + SELECT first_name, last_name, email, phone_number FROM demo_catalog.default.demo_customer_data LIMIT 3; + + -- Check your group membership + SELECT + current_user() AS user, + is_account_group_member('auditor') AS is_auditor, + is_account_group_member('customer_service') AS is_customer_service, + is_account_group_member('marketing') AS is_marketing; ``` -## Demo - -A demonstration of this solution was featured in the 'From PII to GenAI: Architecting for Data Privacy & Security in 2025' webinar. 
- -[![IMAGE ALT TEXT HERE](https://img.youtube.com/vi/x2-wVW04njw/0.jpg)](https://www.youtube.com/watch?v=x2-wVW04njw&t=623s) - ## Key Benefits - **🚀 Pure SQL Performance**: Unity Catalog HTTP connections with zero Python overhead -- **🔒 Column-Level Security**: Automatic role-based data masking via Unity Catalog column masks -- **⚡ Serverless Optimized**: Designed for Databricks serverless compute environments -- **🎯 Simplified Architecture**: Single row-by-row processing - no complex batching required -- **🔧 Easy Integration**: Native Unity Catalog functions work with any SQL client (ODBC, JDBC, notebooks) +- **🔒 Role-Based Security**: Automatic data masking via Unity Catalog column masks +- **⚡ Serverless Ready**: Designed for Databricks serverless compute environments +- **🔧 Easy Integration**: Native Unity Catalog functions work with any SQL client - **📊 Real-time Access Control**: Instant role-based access via `is_account_group_member()` - **🛡️ Graceful Error Handling**: Returns tokens on API failures to ensure data availability -## Architecture +## Role-Based Data Access -The solution leverages Unity Catalog's native capabilities for maximum performance and security: +The solution supports three role-based access levels: -### Implementation Overview +| Role | Group | Data Visibility | Example Output | +|------|-------|----------------|----------------| +| **Auditor** | `auditor` | **Plain text** (detokenized) | `Jonathan` | +| **Customer Service** | `customer_service` | **Masked** (partial hiding) | `J****an` | +| **Marketing** | `marketing` | **Token** (no access) | `4532-8765-9abc-def0` | -1. **Unity Catalog HTTP Connections**: Direct API integration with bearer token authentication -2. **Pure SQL UDFs**: Zero Python overhead, native Spark SQL execution -3. **Column Masks**: Automatic role-based data protection at the table level -4. **UC Secrets**: Secure credential storage with `secret()` function references -5. **Account Groups**: Enterprise-grade role management via `is_account_group_member()` - -### Key Components +### Column Mask Behavior -- **Tokenization Connection**: `skyflow_tokenize_conn` → `/v1/vaults/{vault_id}/{table}` -- **Detokenization Connection**: `skyflow_detokenize_conn` → `/v1/vaults/{vault_id}/detokenize` -- **Role-based UDF**: `sam_skyflow_conditional_detokenize()` - only auditors see real data -- **Column Mask UDF**: `sam_skyflow_mask_detokenize()` - applied at table schema level +```sql +-- Same query, different results based on role: +SELECT customer_id, first_name, email FROM demo_catalog.default.demo_customer_data LIMIT 1; -## Flow Diagrams +-- Auditor sees: CUST00001 | Jonathan | jonathan.anderson@example.com +-- Customer Service: CUST00001 | J****an | j****an.a*****on@example.com +-- Marketing: CUST00001 | 4532-8765... | 9876-5432-abcd... +``` -### Tokenization Flow +**Role Propagation**: After changing user roles, Databricks may take 1-2 minutes to propagate changes. 
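To make the access levels above concrete, the pattern behind them looks roughly like the following. This is a simplified sketch, not the generated function body: the real UDFs are created by setup from the `PLAIN_TEXT_GROUPS` / `MASKED_GROUPS` / `REDACTED_GROUPS` mappings, fall back to returning the token on API errors, and may use a different masking format.

```sql
-- Sketch of the role-aware column mask pattern (simplified for illustration).
CREATE OR REPLACE FUNCTION demo_catalog.default.demo_skyflow_mask_detokenize(val STRING)
RETURNS STRING
RETURN CASE
  -- PLAIN_TEXT_GROUPS: detokenize via the vault
  WHEN is_account_group_member('auditor')
    THEN demo_catalog.default.demo_skyflow_uc_detokenize(val)
  -- MASKED_GROUPS: keep only the edges of the detokenized value (Jonathan -> J****an)
  WHEN is_account_group_member('customer_service')
    THEN regexp_replace(demo_catalog.default.demo_skyflow_uc_detokenize(val),
                        '^(.)(.*)(..)$', '$1****$3')
  -- REDACTED_GROUPS and everyone else: the stored token is returned unchanged
  ELSE val
END;

-- The mask is attached at the table schema level, so every SQL client gets the policy:
ALTER TABLE demo_catalog.default.demo_customer_data
  ALTER COLUMN first_name SET MASK demo_catalog.default.demo_skyflow_mask_detokenize;
```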
-```mermaid -sequenceDiagram - participant Setup as Setup Process - participant Notebook as Tokenization Notebook - participant SF as Skyflow API - participant UC as Unity Catalog - - Setup->>Notebook: Run serverless tokenization - Notebook->>UC: Get secrets via dbutils - - loop For each PII value - Notebook->>SF: POST /v1/vaults/{vault}/table - SF-->>Notebook: Return token - Notebook->>UC: UPDATE table SET column = token - end - - Notebook-->>Setup: Tokenization complete -``` +## Architecture -### Detokenization Flow +### Flow Overview ```mermaid sequenceDiagram @@ -135,155 +100,130 @@ sequenceDiagram participant UDF as Detokenize UDF participant SF as Skyflow API - Client->>UC: SELECT first_name FROM customer_data + Client->>UC: SELECT first_name, email FROM customer_data UC->>UC: Check column mask policy UC->>UDF: Call detokenization function - UDF->>UDF: Check is_account_group_member('auditor') + UDF->>UDF: Check is_account_group_member() alt User is auditor UDF->>SF: POST /detokenize via UC connection SF-->>UDF: Return plain text value UDF-->>UC: Return detokenized data - else User is not auditor + else User is customer_service + UDF-->>UC: Return masked data (no API call) + else User is other role UDF-->>UC: Return token (no API call) end UC-->>Client: Return appropriate data ``` -## Features - -### Data Protection - -- **Row-by-row processing**: Simple, reliable tokenization/detokenization -- **Column masks**: Automatic application at table schema level -- **Unity Catalog integration**: Native secrets and connections management -- **Role-based access**: Account group membership determines data visibility - -### Security - -- **Account-level groups**: Enterprise `is_account_group_member()` integration -- **UC-backed secrets**: Secure credential storage via `secret()` function -- **Bearer token authentication**: Automatic token injection via UC connections -- **Column-level security**: Masks applied at metadata level, not query level - -### Performance - -- **Pure SQL execution**: Zero Python UDF overhead -- **Native Spark SQL**: Full catalyst optimizer integration -- **Serverless optimized**: No cluster management required -- **Connection pooling**: UC manages HTTP connection lifecycle - -### Operational - -- **Organized SQL structure**: Clean separation of setup/destroy/verify operations -- **Graceful error handling**: API failures return tokens to maintain data access -- **ODBC/JDBC compatible**: Works with any SQL client - -## Prerequisites - -1. **Databricks Unity Catalog** with: - - Unity Catalog enabled workspace - - Account-level groups configured - - Serverless or cluster-based compute - - HTTP connections support - -2. 
**Skyflow Account** with: - - Valid PAT token - - Configured vault and table schema - - API access enabled - -## Usage Examples +### Key Components -### Basic Detokenization Query +- **Unity Catalog Connection**: `skyflow_conn` → `/v1/vaults/{vault_id}` (unified endpoint) +- **Role-based UDF**: `{prefix}_skyflow_conditional_detokenize()` - handles all role logic +- **Column Mask UDF**: `{prefix}_skyflow_mask_detokenize()` - applied at table schema level -```sql --- Works with any SQL client (ODBC, JDBC, notebooks, SQL editor) -SELECT - customer_id, - first_name, -- Detokenized for auditors, token for others - last_name, -- Plain text (no column mask) - email, -- Plain text (no column mask) - total_purchases -FROM demo_catalog.default.demo_customer_data -LIMIT 10; -``` +## Python CLI Usage -### Column Mask Behavior +```bash +# Create integration +python setup.py create demo -```sql --- Same query, different results based on role: +# Verify integration +python setup.py verify demo --- Auditor group member sees: --- customer_id | first_name | last_name | email --- CUST00001 | Jonathan | Anderson | jonathan.anderson@example.com +# Destroy integration +python setup.py destroy demo --- Non-auditor sees: --- customer_id | first_name | last_name | email --- CUST00001 | 4532-8765-9abc... | Anderson | jonathan.anderson@example.com +# Get help +python setup.py --help ``` -### Direct Function Calls +## Prerequisites -```sql --- Call detokenization function directly -SELECT demo_skyflow_uc_detokenize('4532-8765-9abc-def0') AS detokenized_value; +1. **Databricks Unity Catalog** with account-level groups configured +2. **Skyflow Account** with valid PAT token and configured vault + +## Databricks Permissions Required + +The user running this solution needs the following Databricks permissions: + +### **Account-Level Permissions** +- **Account Admin** OR **Metastore Admin** (to create catalogs and manage Unity Catalog resources) + +### **Workspace-Level Permissions** +| Permission | Purpose | Required For | +|------------|---------|--------------| +| **Create Cluster/SQL Warehouse** | Job execution | Tokenization notebook runs | +| **Manage Secrets** | Secret scope management | Creating `skyflow-secrets` scope | +| **Workspace Admin** | Resource management | Creating notebooks, dashboards | + +### **Unity Catalog Permissions** +| Resource | Permission | Purpose | +|----------|------------|---------| +| **Metastore** | `CREATE CATALOG` | Creating `{prefix}_catalog` | +| **Catalog** | `USE CATALOG`, `CREATE SCHEMA` | Schema and table creation | +| **Schema** | `USE SCHEMA`, `CREATE TABLE`, `CREATE FUNCTION` | Table and UDF creation | +| **External Locations** | `CREATE CONNECTION` | HTTP connections for Skyflow API | + +### **Required Account Groups** +The solution references these account-level groups (create before deployment): +- `auditor` - Users who see detokenized (plain text) data +- `customer_service` - Users who see masked data (e.g., `J****an`) +- `marketing` - Users who see only tokens + +### **PAT Token Permissions** +Your Databricks PAT token must have: +- **Workspace access** (read/write) +- **Unity Catalog access** (manage catalogs, schemas, functions) +- **SQL Warehouse access** (execute statements) +- **Secrets management** (create/manage secret scopes) + +### **Minimum Setup Command** +```bash +# Grant necessary permissions (run as Account Admin) +databricks account groups create --display-name "auditor" +databricks account groups create --display-name "customer_service" +databricks 
account groups create --display-name "marketing" --- Conditional detokenization (respects role) -SELECT demo_skyflow_conditional_detokenize('4532-8765-9abc-def0') AS role_based_value; +# Add users to appropriate groups +databricks account groups add-member --group-name "auditor" --user-name "user@company.com" ``` -### Role Propagation and Demo Testing - -**Important for demos**: After changing user roles or group membership, Databricks may take 1-2 minutes to propagate the changes. If role-based redaction isn't working as expected, check your current group membership: - -```sql --- Check your current user and group membership -SELECT - current_user() AS user, - is_account_group_member('auditor') AS is_auditor, - is_account_group_member('customer_service') AS is_customer_service, - is_account_group_member('marketing') AS is_marketing; - --- Alternative check using is_member() for workspace groups -SELECT - current_user() AS user, - is_member('auditor') AS is_auditor, - is_member('customer_service') AS is_customer_service, - is_member('marketing') AS is_marketing; +### **Permission Validation** +Test your permissions before deployment: +```bash +# Test configuration and permissions +python setup.py config-test + +# This validates: +# - PAT token authentication +# - Unity Catalog access +# - SQL Warehouse connectivity +# - Required file permissions ``` -If you recently changed roles and the detokenization isn't reflecting the new permissions, wait 1-2 minutes and re-run the query. The functions use both `is_account_group_member()` (for account-level groups) and `is_member()` (for workspace-level groups) to maximize compatibility. +### **Common Permission Issues** +| Error | Cause | Solution | +|-------|--------|----------| +| `PERMISSION_DENIED: User does not have CREATE CATALOG` | Missing metastore admin rights | Grant `Metastore Admin` or `Account Admin` | +| `INVALID_STATE: Cannot create secret scope` | Missing secrets permissions | Grant `Manage` permission on workspace secrets | +| `PERMISSION_DENIED: CREATE CONNECTION` | Missing connection permissions | Ensure Unity Catalog `CREATE CONNECTION` permission | +| `Group 'auditor' not found` | Missing account groups | Create account-level groups first | ## Project Structure ```text -. 
-├── README.md # This file -├── .env.local.example # Environment configuration template -├── config.sh # Configuration loader script -├── setup.sh # Main deployment script -├── sql/ # Organized SQL definitions -│ ├── setup/ # Setup-related SQL files -│ │ ├── create_catalog.sql -│ │ ├── create_sample_table.sql -│ │ ├── create_uc_connections.sql -│ │ ├── setup_uc_connections_api.sql -│ │ └── apply_column_masks.sql -│ ├── destroy/ # Cleanup SQL files -│ │ ├── cleanup_catalog.sql -│ │ ├── drop_functions.sql -│ │ ├── drop_table.sql -│ │ └── remove_column_masks.sql -│ └── verify/ # Verification SQL files -│ ├── verify_functions.sql -│ ├── verify_table.sql -│ ├── check_functions_exist.sql -│ └── check_table_exists.sql -├── notebooks/ # Serverless tokenization -│ └── notebook_tokenize_table.ipynb -└── dashboards/ # Pre-built analytics - └── customer_insights_dashboard.lvdash.json +skyflow_databricks/ # Main Python package +├── cli/ # CLI commands +├── config/ # Configuration management +├── databricks_ops/ # Databricks SDK operations +├── utils/ # Utility functions +└── templates/ # Deployment templates + ├── sql/ # SQL definitions (setup/destroy/verify) + ├── notebooks/ # Serverless tokenization notebook + └── dashboards/ # Pre-built analytics dashboard ``` ## Configuration @@ -292,20 +232,22 @@ If you recently changed roles and the detokenization isn't reflecting the new pe ```bash # Databricks Connection -DATABRICKS_HOST=https://your-workspace.cloud.databricks.com -DATABRICKS_TOKEN=dapi123... -WAREHOUSE_ID=abc123... +DATABRICKS_SERVER_HOSTNAME=your-workspace.cloud.databricks.com +DATABRICKS_PAT_TOKEN=dapi123... +DATABRICKS_HTTP_PATH=/sql/1.0/warehouses/your-warehouse-id # Skyflow Integration -SKYFLOW_VAULT_URL=https://vault.skyflow.com -SKYFLOW_VAULT_ID=abc123... -SKYFLOW_PAT_TOKEN=sky123... -SKYFLOW_TABLE=customer_data - -# Role Mappings (optional) +SKYFLOW_VAULT_URL=https://your-vault.vault.skyflowapis.com +SKYFLOW_VAULT_ID=your-vault-id +SKYFLOW_PAT_TOKEN=eyJhbGci...your-pat-token +SKYFLOW_TABLE=pii +SKYFLOW_TABLE_COLUMN=pii_values +SKYFLOW_BATCH_SIZE=25 + +# Role Mappings (used by functions) PLAIN_TEXT_GROUPS=auditor # See real data -MASKED_GROUPS=customer_service # See masked data -REDACTED_GROUPS=marketing # See redacted data +MASKED_GROUPS=customer_service # See masked data (e.g., J****an) +REDACTED_GROUPS=marketing # See tokens only ``` ### Unity Catalog Setup @@ -313,7 +255,7 @@ REDACTED_GROUPS=marketing # See redacted data The solution creates these UC resources: - **Secrets Scope**: `skyflow-secrets` (UC-backed) -- **HTTP Connections**: `skyflow_tokenize_conn`, `skyflow_detokenize_conn` +- **HTTP Connection**: `skyflow_conn` (unified connection) - **Catalog**: `{prefix}_catalog` with default schema - **Functions**: Pure SQL UDFs for tokenization/detokenization - **Column Masks**: Applied to sensitive columns only @@ -322,83 +264,33 @@ The solution creates these UC resources: ### Adding New PII Columns -1. **Update tokenization**: +1. **Update tokenization**: Edit `skyflow_databricks/databricks_ops/notebooks.py` to include new columns in `pii_columns` +2. **Add column masks**: Edit `skyflow_databricks/templates/sql/setup/apply_column_masks.sql` +3. **Redeploy**: `python setup.py recreate demo` - ```bash - # Edit setup.sh line ~726 - local pii_columns="first_name,last_name,email" - ``` - -2. 
**Add column masks**: +### CLI Features - ```sql - -- Edit sql/setup/apply_column_masks.sql - ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN email SET MASK ${PREFIX}_skyflow_mask_detokenize; - ``` +- **Databricks SDK Integration**: Uses official SDK methods instead of raw API calls +- **Better Error Handling**: Detailed error messages and automatic retry logic +- **Progress Indicators**: Visual progress bars for long-running operations +- **Rich Output**: Colored, formatted output for better readability -3. **Redeploy**: - - ```bash - ./setup.sh recreate demo - ``` - -### Testing Changes - -```bash -# Test individual SQL components -python3 -c " -import os -from setup import execute_sql -execute_sql('sql/verify/verify_functions.sql') -" - -# Full integration test -./setup.sh recreate test -``` - -### Performance Optimization Ideas +## Dashboard Integration -For high-volume scenarios, consider: +The included dashboard demonstrates real-time role-based data access with customer insights, purchase patterns, and consent tracking. The dashboard URL is provided after setup completion. -- **Bulk processing**: 25-token API batches (requires complex result mapping) -- **Connection pooling**: Multiple UC connections for load distribution -- **Caching layer**: Frequently-accessed token caching with TTL -- **Async processing**: Queue-based bulk operations +![databricks_dashboard](https://github.com/user-attachments/assets/f81227c5-fbbf-481c-b7dc-516f64ad6114) ## Cleanup -Remove all resources: - ```bash -./setup.sh destroy demo +python setup.py destroy demo ``` -This removes: - -- Catalog and all objects -- UC connections and secrets -- Functions and column masks -- Notebooks and dashboards - -## Dashboard Integration - -The included dashboard demonstrates real-time role-based data access: - -![databricks_dashboard](https://github.com/user-attachments/assets/f81227c5-fbbf-481c-b7dc-516f64ad6114) - -**Features:** - -- **Customer Overview**: Shows first_name detokenization based on user role -- **Analytics**: Purchase patterns, loyalty analysis, consent tracking -- **Real-time**: Updates automatically as data changes -- **Role-aware**: Same dashboard, different data visibility per user - -**Access**: Dashboard URL provided after setup completion +Removes all UC resources: catalog, connections, secrets, functions, column masks, notebooks, and dashboards. ## Support -For issues and feature requests: - - **Skyflow Documentation**: [docs.skyflow.com](https://docs.skyflow.com) - **Databricks Unity Catalog**: [docs.databricks.com/unity-catalog](https://docs.databricks.com/unity-catalog/) - **GitHub Issues**: Please use the repository issue tracker diff --git a/config.sh b/config.sh deleted file mode 100644 index 14a40f5..0000000 --- a/config.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash - -# Load .env.local if it exists -if [[ -f "$(dirname "$0")/.env.local" ]]; then - echo "Loading configuration from .env.local..." 
- export $(grep -v '^#' "$(dirname "$0")/.env.local" | xargs) -fi - -# Map .env.local variables to our config format (no hardcoded defaults) -if [[ -n "$DATABRICKS_SERVER_HOSTNAME" ]]; then - DEFAULT_DATABRICKS_HOST="https://$DATABRICKS_SERVER_HOSTNAME" -else - DEFAULT_DATABRICKS_HOST="" -fi - -DEFAULT_DATABRICKS_TOKEN="$DATABRICKS_PAT_TOKEN" -# Extract warehouse ID from HTTP path (format: /sql/1.0/warehouses/warehouse-id) -if [[ -n "$DATABRICKS_HTTP_PATH" ]]; then - DEFAULT_WAREHOUSE_ID=$(echo "$DATABRICKS_HTTP_PATH" | sed 's/.*warehouses\///') -else - DEFAULT_WAREHOUSE_ID="" -fi - -# Skyflow settings from .env.local (no hardcoded defaults) -DEFAULT_SKYFLOW_VAULT_URL="$SKYFLOW_VAULT_URL" -DEFAULT_SKYFLOW_VAULT_ID="$SKYFLOW_VAULT_ID" -DEFAULT_SKYFLOW_PAT_TOKEN="$SKYFLOW_PAT_TOKEN" -DEFAULT_SKYFLOW_TABLE="$SKYFLOW_TABLE" - -# Group mappings for detokenization -DEFAULT_PLAIN_TEXT_GROUPS="auditor" -DEFAULT_MASKED_GROUPS="customer_service" -DEFAULT_REDACTED_GROUPS="marketing" - -# Apply any provided values, otherwise use defaults -export DATABRICKS_HOST=${DATABRICKS_HOST:-$DEFAULT_DATABRICKS_HOST} -export DATABRICKS_TOKEN=${DATABRICKS_TOKEN:-$DEFAULT_DATABRICKS_TOKEN} -export WAREHOUSE_ID=${WAREHOUSE_ID:-$DEFAULT_WAREHOUSE_ID} - -export SKYFLOW_VAULT_URL=${SKYFLOW_VAULT_URL:-$DEFAULT_SKYFLOW_VAULT_URL} -export SKYFLOW_VAULT_ID=${SKYFLOW_VAULT_ID:-$DEFAULT_SKYFLOW_VAULT_ID} -export SKYFLOW_PAT_TOKEN=${SKYFLOW_PAT_TOKEN:-$DEFAULT_SKYFLOW_PAT_TOKEN} -export SKYFLOW_TABLE=${SKYFLOW_TABLE:-$DEFAULT_SKYFLOW_TABLE} - -export PLAIN_TEXT_GROUPS=${PLAIN_TEXT_GROUPS:-$DEFAULT_PLAIN_TEXT_GROUPS} -export MASKED_GROUPS=${MASKED_GROUPS:-$DEFAULT_MASKED_GROUPS} -export REDACTED_GROUPS=${REDACTED_GROUPS:-$DEFAULT_REDACTED_GROUPS} diff --git a/notebooks/notebook_tokenize_table.ipynb b/notebooks/notebook_tokenize_table.ipynb deleted file mode 100644 index 8a1eda7..0000000 --- a/notebooks/notebook_tokenize_table.ipynb +++ /dev/null @@ -1,290 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "cell-0", - "metadata": {}, - "outputs": [], - "source": [ - "# Unity Catalog-aware serverless tokenization notebook \n", - "# Uses dbutils.secrets.get() + UC HTTP connections for serverless compatibility\n", - "import json\n", - "import os\n", - "from pyspark.sql import SparkSession\n", - "from pyspark.dbutils import DBUtils\n", - "\n", - "# Initialize Spark session optimized for serverless compute\n", - "spark = SparkSession.builder \\\n", - " .appName(\"SkyflowTokenization\") \\\n", - " .config(\"spark.databricks.cluster.profile\", \"serverless\") \\\n", - " .config(\"spark.databricks.delta.autoCompact.enabled\", \"true\") \\\n", - " .config(\"spark.sql.adaptive.enabled\", \"true\") \\\n", - " .config(\"spark.sql.adaptive.coalescePartitions.enabled\", \"true\") \\\n", - " .getOrCreate()\n", - " \n", - "dbutils = DBUtils(spark)\n", - "\n", - "print(f\"✓ Running on Databricks serverless compute\")\n", - "print(f\"✓ Spark version: {spark.version}\")\n", - "\n", - "# Performance configuration\n", - "MAX_MERGE_BATCH_SIZE = 10000 # Maximum records per MERGE operation\n", - "COLLECT_BATCH_SIZE = 1000 # Maximum records to collect() from Spark at once\n", - "\n", - "# Define widgets to receive input parameters\n", - "dbutils.widgets.text(\"table_name\", \"\")\n", - "dbutils.widgets.text(\"pii_columns\", \"\")\n", - "dbutils.widgets.text(\"batch_size\", \"\") # Skyflow API batch size\n", - "\n", - "# Read widget values\n", - "table_name = dbutils.widgets.get(\"table_name\")\n", - "pii_columns = 
dbutils.widgets.get(\"pii_columns\").split(\",\")\n", - "SKYFLOW_BATCH_SIZE = int(dbutils.widgets.get(\"batch_size\"))\n", - "\n", - "if not table_name or not pii_columns:\n", - " raise ValueError(\"Both 'table_name' and 'pii_columns' must be provided.\")\n", - "\n", - "print(f\"Tokenizing table: {table_name}\")\n", - "print(f\"PII columns: {', '.join(pii_columns)}\")\n", - "print(f\"Skyflow API batch size: {SKYFLOW_BATCH_SIZE}\")\n", - "print(f\"MERGE batch size limit: {MAX_MERGE_BATCH_SIZE:,} records\")\n", - "print(f\"Collect batch size: {COLLECT_BATCH_SIZE:,} records\")\n", - "\n", - "# Extract catalog and schema from table name if fully qualified\n", - "if '.' in table_name:\n", - " parts = table_name.split('.')\n", - " if len(parts) == 3: # catalog.schema.table\n", - " catalog_name = parts[0]\n", - " schema_name = parts[1]\n", - " table_name_only = parts[2]\n", - " \n", - " # Set the catalog and schema context for this session\n", - " print(f\"Setting catalog context to: {catalog_name}\")\n", - " spark.sql(f\"USE CATALOG {catalog_name}\")\n", - " spark.sql(f\"USE SCHEMA {schema_name}\")\n", - " \n", - " # Use the simple table name for queries since context is set\n", - " table_name = table_name_only\n", - " print(f\"✓ Catalog context set, using table name: {table_name}\")\n", - "\n", - "print(\"✓ Using dbutils.secrets.get() + UC HTTP connections for serverless compatibility\")\n", - "\n", - "def tokenize_column_values(column_name, values):\n", - " \"\"\"\n", - " Tokenize a list of PII values using Unity Catalog HTTP connection.\n", - " Uses dbutils.secrets.get() and http_request() for serverless compatibility.\n", - " Returns list of tokens in same order as input values.\n", - " \"\"\"\n", - " if not values:\n", - " return []\n", - " \n", - " # Get secrets using dbutils (works in serverless)\n", - " table_column = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table_column\")\n", - " vault_id = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_vault_id\")\n", - " skyflow_table = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table\")\n", - " \n", - " # Create records for each value\n", - " skyflow_records = [{\n", - " \"fields\": {table_column: str(value)}\n", - " } for value in values if value is not None]\n", - "\n", - " # Create Skyflow tokenization payload\n", - " payload = {\n", - " \"records\": skyflow_records,\n", - " \"tokenization\": True\n", - " }\n", - "\n", - " print(f\" Tokenizing {len(skyflow_records)} values for {column_name}\")\n", - " \n", - " # Use Unity Catalog HTTP connection via SQL http_request function\n", - " json_payload = json.dumps(payload).replace(\"'\", \"''\")\n", - " tokenize_path = f\"{vault_id}/{skyflow_table}\"\n", - " \n", - " # Execute tokenization via consolidated UC connection\n", - " result_df = spark.sql(f\"\"\"\n", - " SELECT http_request(\n", - " conn => 'skyflow_conn',\n", - " method => 'POST',\n", - " path => '{tokenize_path}',\n", - " headers => map(\n", - " 'Content-Type', 'application/json',\n", - " 'Accept', 'application/json'\n", - " ),\n", - " json => '{json_payload}'\n", - " ) as full_response\n", - " \"\"\")\n", - " \n", - " # Parse response\n", - " full_response = result_df.collect()[0]['full_response']\n", - " result = json.loads(full_response.text)\n", - " \n", - " # Fail fast if API response indicates error\n", - " if \"error\" in result:\n", - " raise RuntimeError(f\"Skyflow API error: {result['error']}\")\n", - " \n", - " if \"records\" not in result:\n", - " raise RuntimeError(f\"Invalid Skyflow API response - 
missing 'records': {result}\")\n", - " \n", - " # Extract tokens in order\n", - " tokens = []\n", - " for i, record in enumerate(result.get(\"records\", [])):\n", - " if \"tokens\" in record and table_column in record[\"tokens\"]:\n", - " token = record[\"tokens\"][table_column]\n", - " tokens.append(token)\n", - " else:\n", - " raise RuntimeError(f\"Tokenization failed for value {i+1} in {column_name}. Record: {record}\")\n", - " \n", - " successful_tokens = len([t for i, t in enumerate(tokens) if t and str(t) != str(values[i])])\n", - " print(f\" Successfully tokenized {successful_tokens}/{len(values)} values\")\n", - " \n", - " return tokens\n", - "\n", - "def perform_chunked_merge(table_name, column, update_data):\n", - " \"\"\"\n", - " Perform MERGE operations in chunks to avoid memory/timeout issues.\n", - " Returns total number of rows updated.\n", - " \"\"\"\n", - " if not update_data:\n", - " return 0\n", - " \n", - " total_updated = 0\n", - " chunk_size = MAX_MERGE_BATCH_SIZE\n", - " total_chunks = (len(update_data) + chunk_size - 1) // chunk_size\n", - " \n", - " print(f\" Splitting {len(update_data):,} updates into {total_chunks} MERGE operations (max {chunk_size:,} per chunk)\")\n", - " \n", - " for chunk_idx in range(0, len(update_data), chunk_size):\n", - " chunk_data = update_data[chunk_idx:chunk_idx + chunk_size]\n", - " chunk_num = (chunk_idx // chunk_size) + 1\n", - " \n", - " # Create temporary view for this chunk\n", - " temp_df = spark.createDataFrame(chunk_data, [\"customer_id\", f\"new_{column}\"])\n", - " temp_view_name = f\"temp_{column}_chunk_{chunk_num}_{hash(column) % 1000}\"\n", - " temp_df.createOrReplaceTempView(temp_view_name)\n", - " \n", - " # Perform MERGE operation for this chunk\n", - " merge_sql = f\"\"\"\n", - " MERGE INTO `{table_name}` AS target\n", - " USING {temp_view_name} AS source\n", - " ON target.customer_id = source.customer_id\n", - " WHEN MATCHED THEN \n", - " UPDATE SET `{column}` = source.new_{column}\n", - " \"\"\"\n", - " \n", - " spark.sql(merge_sql)\n", - " chunk_updated = len(chunk_data)\n", - " total_updated += chunk_updated\n", - " \n", - " print(f\" Chunk {chunk_num}/{total_chunks}: Updated {chunk_updated:,} rows\")\n", - " \n", - " # Clean up temp view\n", - " spark.catalog.dropTempView(temp_view_name)\n", - " \n", - " return total_updated\n", - "\n", - "# Process each column individually (streaming approach)\n", - "print(\"Starting column-by-column tokenization with streaming chunked processing...\")\n", - "\n", - "for column in pii_columns:\n", - " print(f\"\\nProcessing column: {column}\")\n", - " \n", - " # Get total count first for progress tracking\n", - " total_count = spark.sql(f\"\"\"\n", - " SELECT COUNT(*) as count \n", - " FROM `{table_name}` \n", - " WHERE `{column}` IS NOT NULL\n", - " \"\"\").collect()[0]['count']\n", - " \n", - " if total_count == 0:\n", - " print(f\" No data found in column {column}\")\n", - " continue\n", - " \n", - " print(f\" Found {total_count:,} total values to tokenize\")\n", - " \n", - " # Process in streaming chunks to avoid memory issues\n", - " all_update_data = [] # Collect all updates before final MERGE\n", - " processed_count = 0\n", - " \n", - " for offset in range(0, total_count, COLLECT_BATCH_SIZE):\n", - " chunk_size = min(COLLECT_BATCH_SIZE, total_count - offset)\n", - " print(f\" Processing chunk {offset//COLLECT_BATCH_SIZE + 1} ({chunk_size:,} records, offset {offset:,})...\")\n", - " \n", - " # Get chunk of data from Spark\n", - " chunk_df = spark.sql(f\"\"\"\n", - " SELECT 
customer_id, `{column}` \n", - " FROM `{table_name}` \n", - " WHERE `{column}` IS NOT NULL \n", - " ORDER BY customer_id\n", - " LIMIT {chunk_size} OFFSET {offset}\n", - " \"\"\")\n", - " \n", - " chunk_rows = chunk_df.collect()\n", - " if not chunk_rows:\n", - " continue\n", - " \n", - " # Extract customer IDs and values for this chunk\n", - " chunk_customer_ids = [row['customer_id'] for row in chunk_rows]\n", - " chunk_column_values = [row[column] for row in chunk_rows]\n", - " \n", - " # Tokenize this chunk's values in Skyflow API batches\n", - " chunk_tokens = []\n", - " if len(chunk_column_values) <= SKYFLOW_BATCH_SIZE: # Single API batch\n", - " chunk_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}\", chunk_column_values)\n", - " else: # Multiple API batches within this chunk\n", - " for i in range(0, len(chunk_column_values), SKYFLOW_BATCH_SIZE):\n", - " api_batch_values = chunk_column_values[i:i + SKYFLOW_BATCH_SIZE]\n", - " api_batch_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}_api_{i//SKYFLOW_BATCH_SIZE + 1}\", api_batch_values)\n", - " chunk_tokens.extend(api_batch_tokens)\n", - " \n", - " # Verify token count matches input count (fail fast)\n", - " if len(chunk_tokens) != len(chunk_customer_ids):\n", - " raise RuntimeError(f\"Token count mismatch: got {len(chunk_tokens)} tokens for {len(chunk_customer_ids)} input values\")\n", - " \n", - " # Collect update data for rows that changed in this chunk\n", - " chunk_original_map = {chunk_customer_ids[i]: chunk_column_values[i] for i in range(len(chunk_customer_ids))}\n", - " \n", - " for i, (customer_id, token) in enumerate(zip(chunk_customer_ids, chunk_tokens)):\n", - " if token and str(token) != str(chunk_original_map[customer_id]):\n", - " all_update_data.append((customer_id, token))\n", - " \n", - " processed_count += len(chunk_rows)\n", - " print(f\" Processed {processed_count:,}/{total_count:,} records ({(processed_count/total_count)*100:.1f}%)\")\n", - " \n", - " # Perform final chunked MERGE operations for all collected updates\n", - " if all_update_data:\n", - " print(f\" Performing final chunked MERGE of {len(all_update_data):,} changed rows...\")\n", - " total_updated = perform_chunked_merge(table_name, column, all_update_data)\n", - " print(f\" ✓ Successfully updated {total_updated:,} rows in column {column}\")\n", - " else:\n", - " print(f\" No updates needed - all tokens match original values\")\n", - "\n", - "print(\"\\nOptimized streaming tokenization completed!\")\n", - "\n", - "# Verify results\n", - "print(\"\\nFinal verification:\")\n", - "for column in pii_columns:\n", - " sample_df = spark.sql(f\"\"\"\n", - " SELECT `{column}`, COUNT(*) as count \n", - " FROM `{table_name}` \n", - " GROUP BY `{column}` \n", - " LIMIT 3\n", - " \"\"\")\n", - " print(f\"\\nSample values in {column}:\")\n", - " sample_df.show(truncate=False)\n", - "\n", - "total_rows = spark.sql(f\"SELECT COUNT(*) as count FROM `{table_name}`\").collect()[0][\"count\"]\n", - "print(f\"\\nTable size: {total_rows:,} total rows\")\n", - "\n", - "print(f\"Optimized streaming tokenization completed for {len(pii_columns)} columns\")" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f2c0335 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +# Databricks SDK and core dependencies +databricks-sdk>=0.20.0 +python-dotenv>=1.0.0 
+click>=8.0.0 + +# Additional utilities +requests>=2.31.0 +pydantic>=2.0.0 +rich>=13.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..52da91e --- /dev/null +++ b/setup.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +Skyflow Databricks Integration Setup Tool + +Modern Python CLI using Databricks SDK for secure PII tokenization. +""" + +import sys +import click +from pathlib import Path +from rich.console import Console +from rich.traceback import install + +# Install rich traceback handler +install() +console = Console() + +# Add the skyflow_databricks directory to Python path +skyflow_dir = Path(__file__).parent / "skyflow_databricks" +sys.path.insert(0, str(skyflow_dir)) + +from cli.commands import CreateCommand, DestroyCommand, VerifyCommand +from config.config import SetupConfig +from utils.logging import setup_logging + + +@click.group() +@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging') +@click.option('--config', '-c', default='.env.local', help='Configuration file path') +@click.pass_context +def cli(ctx, verbose, config): + """Skyflow Databricks Integration Setup Tool.""" + + # Setup logging + log_level = "DEBUG" if verbose else "INFO" + logger = setup_logging(log_level) + + # Store config in context + ctx.ensure_object(dict) + ctx.obj['config_file'] = config + ctx.obj['logger'] = logger + + +@cli.command() +@click.argument('prefix') +@click.pass_context +def create(ctx, prefix): + """Create a new Skyflow Databricks integration.""" + try: + config = SetupConfig(ctx.obj['config_file']) + command = CreateCommand(prefix, config) + success = command.execute() + sys.exit(0 if success else 1) + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + sys.exit(1) + + +@cli.command() +@click.argument('prefix') +@click.pass_context +def destroy(ctx, prefix): + """Destroy an existing Skyflow Databricks integration.""" + + try: + config = SetupConfig(ctx.obj['config_file']) + command = DestroyCommand(prefix, config) + success = command.execute() + sys.exit(0 if success else 1) + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + sys.exit(1) + + +@cli.command() +@click.argument('prefix') +@click.pass_context +def recreate(ctx, prefix): + """Recreate a Skyflow Databricks integration (destroy then create).""" + + try: + config = SetupConfig(ctx.obj['config_file']) + + # Destroy first + console.print("[bold blue]Phase 1: Destroying existing resources[/bold blue]") + destroy_command = DestroyCommand(prefix, config) + destroy_success = destroy_command.execute() + + if not destroy_success: + console.print("[yellow]Warning: Destroy phase had some errors, continuing with create...[/yellow]") + + # Create new + console.print("\n[bold blue]Phase 2: Creating new resources[/bold blue]") + create_command = CreateCommand(prefix, config) + create_success = create_command.execute() + + sys.exit(0 if create_success else 1) + + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + sys.exit(1) + + +@cli.command() +@click.argument('prefix') +@click.pass_context +def verify(ctx, prefix): + """Verify an existing Skyflow Databricks integration.""" + try: + config = SetupConfig(ctx.obj['config_file']) + command = VerifyCommand(prefix, config) + success = command.execute() + sys.exit(0 if success else 1) + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + sys.exit(1) + + +@cli.command() +@click.pass_context +def config_test(ctx): + """Test configuration and Databricks connection.""" 
+ try: + config = SetupConfig(ctx.obj['config_file']) + console.print("[blue]Testing configuration...[/blue]") + config.validate() + + # Test connection + user = config.client.current_user.me() + console.print(f"✓ Connected to Databricks as: {user.user_name}") + console.print(f"✓ Workspace: {config.databricks.host}") + console.print(f"✓ Warehouse ID: {config.databricks.warehouse_id}") + + console.print("[bold green]✓ Configuration test passed[/bold green]") + + except Exception as e: + console.print(f"[red]Configuration test failed: {e}[/red]") + sys.exit(1) + + +if __name__ == '__main__': + cli() \ No newline at end of file diff --git a/setup.sh b/setup.sh deleted file mode 100755 index d02fcbb..0000000 --- a/setup.sh +++ /dev/null @@ -1,1039 +0,0 @@ -#!/bin/bash - -# Check if action and prefix are provided correctly -if [[ "$1" != "create" && "$1" != "destroy" && "$1" != "recreate" ]]; then - echo "Invalid action. Use 'create', 'destroy', or 'recreate'." - echo "Usage: ./setup.sh " - echo "Example: ./setup.sh create demo" - exit 1 -fi - -if [[ "$1" == "create" && -z "$2" ]]; then - echo "Error: Prefix is required for create action" - echo "Usage: ./setup.sh create " - exit 1 -fi - -# Set prefix if provided -if [[ -n "$2" ]]; then - # Convert to lowercase and replace any non-alphanumeric chars with underscore - export PREFIX=$(echo "$2" | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/_/g') - echo "Using prefix: $PREFIX" -fi - -# Function to load configuration values automatically from .env.local -load_config() { - # Source config.sh to get default values and load from .env.local - source "$(dirname "$0")/config.sh" - - echo - - # Set variables from .env.local or defaults, and report what was loaded - export DATABRICKS_HOST=${DEFAULT_DATABRICKS_HOST} - if [[ -n "$DATABRICKS_HOST" ]]; then - echo "✓ DATABRICKS_HOST loaded from .env.local" - fi - - export DATABRICKS_TOKEN=${DEFAULT_DATABRICKS_TOKEN} - if [[ -n "$DATABRICKS_TOKEN" ]]; then - echo "✓ DATABRICKS_TOKEN loaded from .env.local" - fi - - export WAREHOUSE_ID=${DEFAULT_WAREHOUSE_ID} - if [[ -n "$WAREHOUSE_ID" ]]; then - echo "✓ WAREHOUSE_ID loaded from .env.local" - fi - - export SKYFLOW_VAULT_URL=${DEFAULT_SKYFLOW_VAULT_URL} - if [[ -n "$SKYFLOW_VAULT_URL" ]]; then - echo "✓ SKYFLOW_VAULT_URL loaded from .env.local" - fi - - export SKYFLOW_VAULT_ID=${DEFAULT_SKYFLOW_VAULT_ID} - if [[ -n "$SKYFLOW_VAULT_ID" ]]; then - echo "✓ SKYFLOW_VAULT_ID loaded from .env.local" - fi - - export SKYFLOW_PAT_TOKEN=${DEFAULT_SKYFLOW_PAT_TOKEN} - if [[ -n "$SKYFLOW_PAT_TOKEN" ]]; then - echo "✓ SKYFLOW_PAT_TOKEN loaded from .env.local" - fi - - export SKYFLOW_TABLE=${DEFAULT_SKYFLOW_TABLE} - if [[ -n "$SKYFLOW_TABLE" ]]; then - echo "✓ SKYFLOW_TABLE loaded from .env.local" - fi - - # Group mappings with defaults - export PLAIN_TEXT_GROUPS=${DEFAULT_PLAIN_TEXT_GROUPS:-"auditor"} - export MASKED_GROUPS=${DEFAULT_MASKED_GROUPS:-"customer_service"} - export REDACTED_GROUPS=${DEFAULT_REDACTED_GROUPS:-"marketing"} - echo "✓ Group mappings set: PLAIN_TEXT=${PLAIN_TEXT_GROUPS}, MASKED=${MASKED_GROUPS}, REDACTED=${REDACTED_GROUPS}" - - echo -e "\nConfiguration loaded successfully." -} - -# Function to substitute environment variables in content -substitute_variables() { - local content=$1 - echo "$content" | envsubst -} - -# Function to setup Unity Catalog connections via SQL file -setup_uc_connections() { - echo "Creating Unity Catalog connections via SQL..." 
- - # Read SQL file and process each connection separately (direct API call to avoid catalog context) - local sql_content=$(cat "sql/setup/create_uc_connections.sql") - local processed_content=$(substitute_variables "$sql_content") - - # Execute consolidated connection creation with detailed logging (direct API call to avoid catalog context) - echo "Executing consolidated Skyflow connection SQL without catalog context..." - local connection_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/sql/statements" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"statement\":$(echo "$processed_content" | python3 -c 'import json,sys; print(json.dumps(sys.stdin.read()))'),\"warehouse_id\":\"${WAREHOUSE_ID}\"}") - - if echo "$connection_response" | grep -q '"state":"SUCCEEDED"'; then - echo "✓ Created UC consolidated connection: skyflow_conn" - local connection_success=true - else - echo "❌ ERROR: Consolidated connection creation failed" - echo "SQL statement was: $processed_content" - echo "Response: $connection_response" - local connection_success=false - fi - - # Verify connection actually exists after creation - echo "Verifying UC connection was actually created..." - local actual_connections=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.1/unity-catalog/connections" | \ - python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - skyflow_conns = [c['name'] for c in data.get('connections', []) if 'skyflow' in c['name'].lower()] - print(' '.join(skyflow_conns)) -except: - print('') -") - - if echo "$actual_connections" | grep -q "skyflow_conn"; then - echo "✓ Verified skyflow_conn exists" - else - echo "❌ skyflow_conn NOT FOUND after creation" - connection_success=false - fi - - # Return success only if connection succeeded - if [ "$connection_success" = true ]; then - echo "✓ UC consolidated connection created successfully via SQL" - return 0 - else - echo "❌ Failed to create required UC connection" - echo "Connection must be created successfully for setup to proceed" - exit 1 - fi -} - -# Function to setup Unity Catalog secrets via Databricks REST API -setup_uc_secrets() { - echo "Creating Unity Catalog secrets scope..." 
- - # Create UC-backed secrets scope - local scope_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/secrets/scopes/create" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d '{ - "scope": "skyflow-secrets", - "scope_backend_type": "UC" - }') - - # Check if scope creation failed (ignore if already exists) - if echo "$scope_response" | grep -q '"error_code":"RESOURCE_ALREADY_EXISTS"'; then - echo "✓ Secrets scope already exists" - elif echo "$scope_response" | grep -q '"error_code"'; then - echo "Error creating secrets scope: $scope_response" - return 1 - else - echo "✓ Created secrets scope successfully" - fi - - # Create individual secrets (only ones actually needed for tokenization/detokenization) - local secrets=( - "skyflow_pat_token:${SKYFLOW_PAT_TOKEN}" - "skyflow_vault_url:${SKYFLOW_VAULT_URL}" - "skyflow_vault_id:${SKYFLOW_VAULT_ID}" - "skyflow_table:${SKYFLOW_TABLE}" - "skyflow_table_column:${SKYFLOW_TABLE_COLUMN}" - ) - - for secret_pair in "${secrets[@]}"; do - IFS=':' read -r key value <<< "$secret_pair" - echo "Creating secret: $key" - - local secret_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/secrets/put" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{ - \"scope\": \"skyflow-secrets\", - \"key\": \"$key\", - \"string_value\": \"$value\" - }") - - if echo "$secret_response" | grep -q '"error_code"'; then - echo "Warning: Error creating secret $key: $secret_response" - else - echo "✓ Created secret: $key" - fi - done - - return 0 -} - -# Function to create a notebook using Databricks REST API -create_notebook() { - local path=$1 - local source_file=$2 - - # Check if source file exists - if [[ ! -f "$source_file" ]]; then - echo "Error: Source file not found: $source_file" - return 1 - fi - - # Read file content and substitute variables - local content=$(cat "$source_file") - local processed_content=$(substitute_variables "$content") - - # Base64 encode the processed content - if [[ "$OSTYPE" == "darwin"* ]]; then - # macOS - encoded_content=$(echo "$processed_content" | base64) - else - # Linux - try -w 0 first, fall back to plain base64 - encoded_content=$(echo "$processed_content" | base64 -w 0 2>/dev/null || echo "$processed_content" | base64 | tr -d '\n') - fi - - local response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/workspace/import" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{ - \"path\": \"${path}\", - \"format\": \"JUPYTER\", - \"content\": \"${encoded_content}\", - \"overwrite\": true - }") - - # Check for errors - if echo "$response" | grep -q "error"; then - echo "Error creating notebook:" - echo "$response" - return 1 - else - echo "Notebook created successfully at ${path}" - fi -} - -# Function to execute SQL using Databricks REST API -execute_sql() { - local sql_file=$1 - - local statements=() - - # Handle direct SQL statements vs SQL files - if [[ "$sql_file" == DROP* || "$sql_file" == DESCRIBE* ]]; then - # For DROP and DESCRIBE commands, use the command directly - statements+=("$sql_file") - else - # For SQL files, read and process content - if [[ ! 
-f "$sql_file" ]]; then - echo "Error: SQL file not found: $sql_file" - return 1 - fi - - # Read and process SQL file - local sql_content=$(cat "$sql_file") - local processed_content=$(substitute_variables "$sql_content") - local current_statement="" - - # Split into statements, handling multi-line SQL properly - while IFS= read -r line || [[ -n "$line" ]]; do - # Skip empty lines and comments - if [[ -z "${line// }" ]] || [[ "$line" =~ ^[[:space:]]*-- ]]; then - continue - fi - - # Add line to current statement with exact whitespace preservation - if [[ -z "$current_statement" ]]; then - current_statement="${line}" - else - # Preserve exact line including all whitespace - current_statement+=$'\n'"${line}" - fi - - # If line ends with semicolon, it's end of statement - if [[ "$line" =~ \;[[:space:]]*$ ]]; then - if [[ -n "$current_statement" ]]; then - statements+=("$current_statement") - fi - current_statement="" - fi - done <<< "$processed_content" - fi - - # Execute each statement - for statement in "${statements[@]}"; do - # Properly escape for JSON while preserving newlines - # Write statement to temp file to avoid shell interpretation - local temp_file=$(mktemp) - printf "%s" "$statement" > "$temp_file" - - # Use Python to properly escape while preserving exact formatting - local json_statement=$(python3 -c ' -import json, sys -with open("'"$temp_file"'", "r") as f: - content = f.read() -print(json.dumps(content)) -') - rm "$temp_file" - - # Use dedicated catalog if available, otherwise main - local catalog_context="${CATALOG_NAME:-main}" - local response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/sql/statements" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"statement\":${json_statement},\"catalog\":\"${catalog_context}\",\"schema\":\"default\",\"warehouse_id\":\"${WAREHOUSE_ID}\"}") - - # Check for errors in response - if echo "$response" | grep -q "error"; then - echo "Error executing SQL:" - echo "Statement was: $statement" - echo "Response: $response" - return 1 - fi - done -} - -# Function to run a notebook using Databricks Runs API -run_notebook() { - local notebook_path=$1 - local table_name=$2 - local pii_columns=$3 - local batch_size=${4:-${SKYFLOW_BATCH_SIZE}} # Use provided batch size or env default - - echo "Running notebook: ${notebook_path}" - echo "Batch size: ${batch_size}" - - # Create job run with notebook parameters using serverless compute - # Note: For serverless, we must use multi-task format with tasks array - local run_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.1/jobs/runs/submit" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{ - \"run_name\": \"Serverless_Tokenize_${table_name}_$(date +%s)\", - \"tasks\": [ - { - \"task_key\": \"tokenize_task\", - \"notebook_task\": { - \"notebook_path\": \"${notebook_path}\", - \"source\": \"WORKSPACE\", - \"base_parameters\": { - \"table_name\": \"${table_name}\", - \"pii_columns\": \"${pii_columns}\", - \"batch_size\": \"${batch_size}\" - } - }, - \"timeout_seconds\": 1800 - } - ] - }") - - # Extract run ID - local run_id=$(echo "$run_response" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - if 'run_id' in data: - print(data['run_id']) - else: - print('ERROR: ' + str(data)) - sys.exit(1) -except Exception as e: - print('ERROR: ' + str(e)) - sys.exit(1) -") - - if [[ "$run_id" == ERROR* ]]; then - echo "Failed to start notebook run: $run_id" - return 1 - fi - - echo "Started 
notebook run with ID: $run_id" - - # Extract workspace ID from hostname for task URL - local workspace_id=$(echo "$DATABRICKS_HOST" | sed 's/https:\/\/dbc-//' | sed 's/-.*\.cloud\.databricks\.com.*//') - echo "View live logs: ${DATABRICKS_HOST}/jobs/runs/${run_id}?o=${workspace_id}" - - # Wait for run to complete - echo "Waiting for tokenization to complete..." - local max_wait=900 # 15 minutes - local wait_time=0 - - while [[ $wait_time -lt $max_wait ]]; do - local status_response=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.1/jobs/runs/get?run_id=${run_id}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}") - - local run_state=$(echo "$status_response" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - state = data.get('state', {}).get('life_cycle_state', 'UNKNOWN') - print(state) -except: - print('UNKNOWN') -") - - case $run_state in - "TERMINATED") - local result_state=$(echo "$status_response" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - result = data.get('state', {}).get('result_state', 'UNKNOWN') - print(result) -except: - print('UNKNOWN') -") - if [[ "$result_state" == "SUCCESS" ]]; then - echo "✅ Tokenization completed successfully" - return 0 - else - echo "❌ Tokenization failed with result: $result_state" - return 1 - fi - ;; - "INTERNAL_ERROR"|"FAILED"|"TIMEDOUT"|"CANCELED"|"SKIPPED") - echo "❌ Tokenization run failed with state: $run_state" - return 1 - ;; - *) - echo "Tokenization in progress... (state: $run_state)" - sleep 30 - wait_time=$((wait_time + 30)) - ;; - esac - done - - echo "❌ Tokenization timed out after ${max_wait} seconds" - return 1 -} - -# Function to import dashboard using Databricks Lakeview API -create_dashboard() { - local path=$1 - local source_file=$2 - - # Check if source file exists - if [[ ! -f "$source_file" ]]; then - echo "Error: Dashboard file not found: $source_file" - return 1 - fi - - # Read file content and substitute variables - local content=$(cat "$source_file") - - # Substitute variables and stringify dashboard content - local processed_content=$(substitute_variables "$content" | python3 -c 'import json,sys; print(json.dumps(json.dumps(json.load(sys.stdin))))') - - # Check if dashboard already exists and delete it - dashboards=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json") - - existing_id=$(echo "$dashboards" | python3 -c " -import sys, json -data = json.load(sys.stdin) -prefix = '${PREFIX}' -ids = [d['dashboard_id'] for d in data.get('dashboards', []) - if d.get('display_name', '') == prefix + '_customer_insights'] -print(ids[0] if ids else '') -") - - if [[ -n "$existing_id" ]]; then - echo "Deleting existing dashboard..." - curl -s -X DELETE "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards/${existing_id}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" - - # Wait for deletion to complete - echo "Waiting for dashboard deletion to complete..." 
- sleep 5 - - # Verify deletion - local max_retries=3 - local retry_count=0 - while [[ $retry_count -lt $max_retries ]]; do - dashboards=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json") - - still_exists=$(echo "$dashboards" | python3 -c " -import sys, json -data = json.load(sys.stdin) -prefix = '${PREFIX}' -ids = [d['dashboard_id'] for d in data.get('dashboards', []) - if d.get('display_name', '') == prefix + '_customer_insights'] -print('true' if ids else 'false') -") - - if [[ "$still_exists" == "false" ]]; then - break - fi - - echo "Dashboard still exists, waiting..." - sleep 5 - ((retry_count++)) - done - fi - - # Create dashboard using Lakeview API - local payload=$(echo "{ - \"display_name\": \"${PREFIX}_customer_insights\", - \"warehouse_id\": \"${WAREHOUSE_ID}\", - \"serialized_dashboard\": ${processed_content}, - \"parent_path\": \"/Shared\" - }") - - local response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "$payload") - - # Check for errors - if echo "$response" | grep -q "error\|Error\|ERROR"; then - echo "Error creating dashboard:" - echo "$response" - return 1 - fi - - # Extract dashboard ID from response, handling both stdout and stderr from curl - local dashboard_id=$(echo "$response" | grep -o '"dashboard_id":"[^"]*"' | cut -d'"' -f4) - if [[ -z "$dashboard_id" ]]; then - echo "Error: Could not extract dashboard ID from response" - return 1 - fi - echo "$dashboard_id" -} - -# Function to create metastore -create_metastore() { - echo "Creating dedicated metastore..." - local metastore_name="${PREFIX}_metastore" - - # Create metastore with only required fields (no S3 bucket or IAM role) - local metastore_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.1/unity-catalog/metastores" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{ - \"name\": \"${metastore_name}\", - \"region\": \"${DATABRICKS_METASTORE_REGION:-us-west-1}\" - }") - - # Extract metastore ID - local metastore_id=$(echo "$metastore_response" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - if 'metastore_id' in data: - print(data['metastore_id']) - else: - print('ERROR: ' + str(data)) - sys.exit(1) -except Exception as e: - print('ERROR: ' + str(e)) - sys.exit(1) -") - - if [[ "$metastore_id" == ERROR* ]]; then - echo "Failed to create metastore: $metastore_id" - return 1 - fi - - echo "Metastore created with ID: $metastore_id" - export METASTORE_ID="$metastore_id" - - # Assign metastore to current workspace - echo "Assigning metastore to workspace..." - local assignment_response=$(curl -s -X PUT "${DATABRICKS_HOST}/api/2.1/unity-catalog/current-metastore-assignment" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"metastore_id\": \"${metastore_id}\"}") - - # Wait for assignment to propagate - echo "Waiting for metastore assignment..." - sleep 10 - - return 0 -} - -# Function to destroy metastore -destroy_metastore() { - echo "Finding and destroying metastore..." 
- local metastore_name="${PREFIX}_metastore" - - # Get metastore ID by name - local metastores=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.1/unity-catalog/metastores") - - local metastore_id=$(echo "$metastores" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - for ms in data.get('metastores', []): - if ms.get('name') == '${metastore_name}': - print(ms.get('metastore_id', '')) - break - else: - print('') -except: - print('') -") - - if [[ -n "$metastore_id" ]]; then - echo "Found metastore ID: $metastore_id" - - # Delete metastore - echo "Deleting metastore..." - curl -s -X DELETE "${DATABRICKS_HOST}/api/2.1/unity-catalog/metastores/${metastore_id}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -d '{"force": true}' - - echo "Metastore deletion initiated" - return 0 - else - echo "Metastore ${metastore_name} not found" - return 0 - fi -} - -# Function to check directory existence -check_directories() { - local dirs=("notebooks" "sql" "sql/setup" "sql/destroy" "sql/verify" "dashboards") - local missing=0 - - for dir in "${dirs[@]}"; do - if [[ ! -d "$dir" ]]; then - echo "Error: Required directory not found: $dir" - missing=1 - fi - done - - if [[ $missing -eq 1 ]]; then - echo "Please ensure all required directories exist before running setup." - exit 1 - fi -} - -create_components() { - echo "Creating resources with prefix: ${PREFIX}" - - # Verify required directories exist - check_directories || exit 1 - - # Create dedicated catalog for this instance (instead of metastore) - echo "Creating dedicated catalog: ${PREFIX}_catalog" - execute_sql "sql/setup/create_catalog.sql" || exit 1 - - # Use our dedicated catalog instead of main - export CATALOG_NAME="${PREFIX}_catalog" - - # Create Shared directory if it doesn't exist - echo "Creating Shared directory..." - curl -s -X POST "${DATABRICKS_HOST}/api/2.0/workspace/mkdirs" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"path\": \"/Workspace/Shared\"}" - - # Catalog and schema setup complete - - # Create sample table - echo "Creating sample table..." - execute_sql "sql/setup/create_sample_table.sql" || exit 1 - - # Create tokenization notebook - echo "Creating tokenization notebook..." - create_notebook "/Workspace/Shared/${PREFIX}_tokenize_table" "notebooks/notebook_tokenize_table.ipynb" || exit 1 - - # Setup Unity Catalog secrets via REST API - echo "Setting up Unity Catalog secrets via Databricks API..." - setup_uc_secrets || exit 1 - - # Create UC connections (required) - echo "Creating Unity Catalog connections..." - setup_uc_connections || exit 1 - - echo "Creating Pure SQL detokenization functions with UC connections..." - execute_sql "sql/setup/setup_uc_connections_api.sql" || exit 1 - echo "✓ Using UC connections approach (pure SQL, highest performance)" - - # Brief pause to ensure function is fully created - echo "Verifying function creation..." - sleep 5 - - # Verify Unity Catalog functions are created - echo "Verifying Unity Catalog detokenization functions..." - execute_sql "sql/verify/verify_functions.sql" || exit 1 - - # Check if table exists before applying masks - echo "Verifying table exists..." 
- execute_sql "sql/verify/verify_table.sql" || exit 1 - - # The conditional detokenization function is already created by create_uc_detokenize_functions.sql - echo "✓ Unity Catalog conditional detokenization functions created" - - # Tokenize the sample data FIRST (before applying masks) - echo "Tokenizing PII data in sample table..." - local pii_columns="first_name" - local table_name="${PREFIX}_catalog.default.${PREFIX}_customer_data" - run_notebook "/Workspace/Shared/${PREFIX}_tokenize_table" "${table_name}" "${pii_columns}" "${SKYFLOW_BATCH_SIZE}" || exit 1 - - # Apply column masks to PII columns AFTER tokenization - echo "Applying column masks to tokenized PII columns..." - execute_sql "sql/setup/apply_column_masks.sql" || exit 1 - - # Create dashboard - echo "Creating dashboard..." - dashboard_id=$(create_dashboard "${PREFIX}_customer_insights" "dashboards/customer_insights_dashboard.lvdash.json" | tail -n 1) || exit 1 - echo "Dashboard created successfully" - echo "Dashboard URL: ${DATABRICKS_HOST}/sql/dashboardsv3/${dashboard_id}" - - echo "Setup complete! Created resources with prefix: ${PREFIX}" - echo " -Resources created: -1. Dedicated Unity Catalog: ${PREFIX}_catalog -2. Sample table: ${PREFIX}_catalog.default.${PREFIX}_customer_data (with tokenized PII and column masks applied) -3. Tokenization notebook: - - /Workspace/Shared/${PREFIX}_tokenize_table (serverless-optimized) -4. Unity Catalog Infrastructure: - - SQL-created HTTP connection: skyflow_conn (consolidated for tokenization and detokenization) - - UC-backed secrets scope: skyflow-secrets (contains PAT token, vault details) - - Bearer token authentication with proper secret() references -5. Pure SQL Functions: - - ${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize (direct Skyflow API via UC connections) - - ${PREFIX}_catalog.default.${PREFIX}_skyflow_conditional_detokenize (role-based access control) - - ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize (column mask wrapper) -6. Column masks applied to ALL PII columns - only 'auditor' group sees detokenized data -7. Dashboard: ${PREFIX}_customer_insights (catalog-qualified queries) -8. ✅ PII data automatically tokenized during setup - -Usage: -1. Data is already tokenized and ready to use! - -2. To query data with automatic detokenization: - Run: SELECT * FROM ${PREFIX}_catalog.default.${PREFIX}_customer_data - (PII columns automatically detokenized based on user role) - -3. For bulk detokenization: - Run: SELECT ${PREFIX}_catalog.default.${PREFIX}_skyflow_bulk_detokenize(array('token1', 'token2'), current_user()) - -3. To view data: - Dashboard URL: ${DATABRICKS_HOST}/sql/dashboardsv3/${dashboard_id:-""} -" -} - -destroy_components() { - echo "Destroying resources with prefix: ${PREFIX}" - local failed_deletions=() - local successful_deletions=() - - # Drop Unity Catalog functions (including test functions) - echo "Dropping Unity Catalog detokenization functions..." 
- - # Set catalog context to dedicated catalog if it exists, otherwise skip function drops - local catalog_name="${PREFIX}_catalog" - local catalog_check_response=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.1/unity-catalog/catalogs/${catalog_name}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}") - - if echo "$catalog_check_response" | grep -q '"name"'; then - echo "Using catalog: ${catalog_name}" - CATALOG_NAME="${catalog_name}" execute_sql "sql/destroy/drop_functions.sql" - else - echo "Dedicated catalog ${catalog_name} not found, skipping function drops" - fi - - # Clean up UC secrets scope - echo "Cleaning up Unity Catalog secrets..." - local delete_scope_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/secrets/scopes/delete" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d '{"scope": "skyflow-secrets"}') - - if echo "$delete_scope_response" | grep -q '"error_code"'; then - echo "Warning: Could not delete secrets scope: $delete_scope_response" - else - echo "✓ Deleted secrets scope: skyflow-secrets" - fi - - # Drop Unity Catalog connections (both SQL-created and REST-created) - echo "Cleaning up Unity Catalog connections..." - local connections_response=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.1/unity-catalog/connections" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}") - - # Extract connection names that are Skyflow-related - local connection_names=$(echo "$connections_response" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - names = [conn['name'] for conn in data.get('connections', []) - if conn['name'] in ['skyflow_conn', 'skyflow_tokenize_conn', 'skyflow_detokenize_conn']] - print('\n'.join(names)) -except Exception as e: - print(f'Error extracting names: {e}') - print('') -") - - # Delete each matching connection - if [[ -n "$connection_names" ]]; then - while IFS= read -r conn_name; do - if [[ -n "$conn_name" ]]; then - echo "Deleting UC connection: ${conn_name}" - local delete_response=$(curl -s -X DELETE "${DATABRICKS_HOST}/api/2.1/unity-catalog/connections/${conn_name}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}") - if echo "$delete_response" | grep -q '"error_code"'; then - echo " Warning: Error deleting $conn_name: $delete_response" - else - echo " ✓ Deleted connection: $conn_name" - fi - fi - done <<< "$connection_names" - else - echo "No connections found to delete" - fi - - # Verify UC function deletions - if execute_sql "sql/verify/check_functions_exist.sql" &>/dev/null; then - failed_deletions+=("Function: ${PREFIX}_skyflow_uc_detokenize") - failed_deletions+=("Function: ${PREFIX}_skyflow_mask_detokenize") - else - successful_deletions+=("Function: ${PREFIX}_skyflow_uc_detokenize") - successful_deletions+=("Function: ${PREFIX}_skyflow_mask_detokenize") - fi - - # Remove column masks before dropping table - echo "Removing column masks..." - execute_sql "sql/destroy/remove_column_masks.sql" &>/dev/null || true - - # Drop the table - echo "Dropping sample table..." - execute_sql "sql/destroy/drop_table.sql" - # Verify table deletion - if execute_sql "sql/verify/check_table_exists.sql" &>/dev/null; then - failed_deletions+=("Table: ${PREFIX}_customer_data") - else - successful_deletions+=("Table: ${PREFIX}_customer_data") - fi - - # Delete notebook - echo "Deleting tokenization notebook..." 
- local notebook_paths=("/Workspace/Shared/${PREFIX}_tokenize_table") - for notebook_path in "${notebook_paths[@]}"; do - echo "Deleting notebook: ${notebook_path}" - local response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/workspace/delete" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"path\": \"${notebook_path}\", \"recursive\": true}") - - # Verify notebook deletion by trying to get its info - local verify_response=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.0/workspace/get-status" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"path\": \"${notebook_path}\"}") - - if echo "$verify_response" | grep -q "error_code.*RESOURCE_DOES_NOT_EXIST"; then - successful_deletions+=("Notebook: ${notebook_path}") - else - failed_deletions+=("Notebook: ${notebook_path}") - fi - done - - # Delete dashboard - echo "Deleting dashboard..." - # Get all dashboards and find matching ones - dashboards=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json") - - # Use Python to handle the JSON parsing and find all matching dashboards - matching_ids=$(echo "$dashboards" | python3 -c " -import sys, json -data = json.load(sys.stdin) -prefix = '${PREFIX}' -ids = [d['dashboard_id'] for d in data.get('dashboards', []) - if d.get('display_name', '').startswith(prefix)] -print('\n'.join(ids)) -") - - # Delete each matching dashboard - while IFS= read -r dashboard_id; do - if [[ -n "$dashboard_id" ]]; then - echo "Deleting dashboard with ID: ${dashboard_id}" - curl -s -X DELETE "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards/${dashboard_id}" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" - - # Verify dashboard deletion - sleep 2 # Brief pause to allow deletion to propagate - local verify_dashboards=$(curl -s -X GET "${DATABRICKS_HOST}/api/2.0/lakeview/dashboards" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json") - - local still_exists=$(echo "$verify_dashboards" | python3 -c " -import sys, json -data = json.load(sys.stdin) -dashboard_id = '${dashboard_id}' -exists = any(d['dashboard_id'] == dashboard_id for d in data.get('dashboards', [])) -print('true' if exists else 'false') -") - - if [[ "$still_exists" == "false" ]]; then - successful_deletions+=("Dashboard: ID ${dashboard_id}") - else - failed_deletions+=("Dashboard: ID ${dashboard_id}") - fi - fi - done <<< "$matching_ids" - - # Destroy dedicated catalog - echo "Destroying dedicated catalog..." - execute_sql "sql/destroy/cleanup_catalog.sql" || echo "Failed to drop catalog, continuing..." - - # Print summary - echo -e "\nDestroy Summary:" - if [[ ${#successful_deletions[@]} -gt 0 ]]; then - echo -e "\nSuccessfully deleted:" - printf '%s\n' "${successful_deletions[@]}" - fi - - if [[ ${#failed_deletions[@]} -gt 0 ]]; then - echo -e "\nFailed to delete:" - printf '%s\n' "${failed_deletions[@]}" - echo -e "\nWarning: Some resources could not be verified as deleted. They may require manual cleanup." - exit 1 - else - echo -e "\nAll resources successfully deleted!" 
- fi -} - -# Function to check if resources already exist -check_existing_resources() { - local has_existing=false - local existing_resources=() - - echo "Checking for existing resources with prefix: ${PREFIX}" - - # Check for catalog - local catalog_check=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.1/unity-catalog/catalogs" | \ - python3 -c " -import sys, json -try: - data = json.load(sys.stdin) - catalogs = [c['name'] for c in data.get('catalogs', [])] - exists = '${PREFIX}_catalog' in catalogs - print('true' if exists else 'false') -except: - print('false') -") - - if [[ "$catalog_check" == "true" ]]; then - has_existing=true - existing_resources+=("Catalog: ${PREFIX}_catalog") - fi - - # Check for secrets scope - local secrets_check=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.0/secrets/scopes/list" | \ - python3 -c " -import sys, json -try: - data = json.load(sys.stdin) - scopes = [s['name'] for s in data.get('scopes', [])] - exists = 'skyflow-secrets' in scopes - print('true' if exists else 'false') -except: - print('false') -") - - if [[ "$secrets_check" == "true" ]]; then - has_existing=true - existing_resources+=("Secrets scope: skyflow-secrets") - fi - - # Check for UC connections - local connections_check=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.1/unity-catalog/connections" | \ - python3 -c " -import sys, json -try: - data = json.load(sys.stdin) - connections = [c['name'] for c in data.get('connections', [])] - skyflow_conn_exists = 'skyflow_conn' in connections - # Also check for old connections in case of partial migration - tokenize_exists = 'skyflow_tokenize_conn' in connections - detokenize_exists = 'skyflow_detokenize_conn' in connections - print('true' if (skyflow_conn_exists or tokenize_exists or detokenize_exists) else 'false') -except: - print('false') -") - - if [[ "$connections_check" == "true" ]]; then - has_existing=true - existing_resources+=("UC connections: skyflow_*_conn") - fi - - # Check for notebooks - local notebook_path="/Workspace/Shared/${PREFIX}_tokenize_table" - local notebook_check=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - "${DATABRICKS_HOST}/api/2.0/workspace/get-status" \ - -d "{\"path\": \"${notebook_path}\"}" | \ - python3 -c " -import sys, json -try: - data = json.load(sys.stdin) - exists = 'path' in data - print('true' if exists else 'false') -except: - print('false') -") - - if [[ "$notebook_check" == "true" ]]; then - has_existing=true - existing_resources+=("Notebook: ${notebook_path}") - fi - - if [[ "$has_existing" == "true" ]]; then - echo "" - echo "❌ ERROR: Existing resources found that would conflict with setup:" - printf ' - %s\n' "${existing_resources[@]}" - echo "" - echo "Please run one of the following commands first:" - echo " ./setup.sh destroy # Remove existing resources" - echo " ./setup.sh recreate ${PREFIX} # Remove and recreate all resources" - echo "" - exit 1 - else - echo "✅ No conflicting resources found. Proceeding with setup..." 
- fi -} - -# Main logic -load_config - -if [[ "$1" == "create" ]]; then - check_existing_resources - create_components -elif [[ "$1" == "destroy" ]]; then - destroy_components -elif [[ "$1" == "recreate" ]]; then - destroy_components - create_components -fi diff --git a/skyflow_databricks/__init__.py b/skyflow_databricks/__init__.py new file mode 100644 index 0000000..0dd73c1 --- /dev/null +++ b/skyflow_databricks/__init__.py @@ -0,0 +1 @@ +# Skyflow Databricks Setup Package \ No newline at end of file diff --git a/skyflow_databricks/cli/__init__.py b/skyflow_databricks/cli/__init__.py new file mode 100644 index 0000000..c38de79 --- /dev/null +++ b/skyflow_databricks/cli/__init__.py @@ -0,0 +1 @@ +# CLI interface \ No newline at end of file diff --git a/skyflow_databricks/cli/commands.py b/skyflow_databricks/cli/commands.py new file mode 100644 index 0000000..13f3250 --- /dev/null +++ b/skyflow_databricks/cli/commands.py @@ -0,0 +1,478 @@ +"""CLI command implementations for Databricks Skyflow integration.""" + +import time +from typing import Optional +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +from config.config import SetupConfig +from databricks_ops.unity_catalog import UnityCatalogManager +from databricks_ops.secrets import SecretsManager +from databricks_ops.sql import SQLExecutor +from databricks_ops.notebooks import NotebookManager +from databricks_ops.dashboards import DashboardManager +from utils.validation import validate_prefix, validate_required_files + +console = Console() + + +class BaseCommand: + """Base class for all commands.""" + + def __init__(self, prefix: str, config: Optional[SetupConfig] = None): + self.prefix = prefix + self.config = config or SetupConfig() + + # Validate prefix + is_valid, error = validate_prefix(prefix) + if not is_valid: + raise ValueError(f"Invalid prefix: {error}") + + def validate_environment(self): + """Validate environment and configuration.""" + try: + self.config.validate() + except ValueError as e: + console.print(f"[red]Configuration error: {e}[/red]") + raise + + +class CreateCommand(BaseCommand): + """Implementation of 'create' command.""" + + def execute(self) -> bool: + """Execute the create command.""" + console.print(Panel.fit( + f"Creating Skyflow Databricks Integration: [bold]{self.prefix}[/bold]", + style="green" + )) + + try: + # Always destroy first to ensure clean state + console.print(f"[dim]Cleaning up any existing '{self.prefix}' resources...[/dim]") + destroy_command = DestroyCommand(self.prefix, self.config) + destroy_command.execute() # Don't fail if destroy has issues + + # Validate environment + self.validate_environment() + + # Check required files exist + required_files = [ + "sql/setup/create_sample_table.sql", + "sql/setup/create_uc_connections.sql", + "sql/setup/setup_uc_connections_api.sql", + "sql/setup/apply_column_masks.sql", + "notebooks/notebook_tokenize_table.ipynb", + "dashboards/customer_insights_dashboard.lvdash.json" + ] + + files_exist, missing = validate_required_files(required_files) + if not files_exist: + console.print(f"[red]Missing required files: {', '.join(missing)}[/red]") + return False + + # Initialize managers + uc_manager = UnityCatalogManager(self.config.client) + secrets_manager = SecretsManager(self.config.client) + sql_executor = SQLExecutor(self.config.client, self.config.databricks.warehouse_id) + notebook_manager = NotebookManager(self.config.client) + dashboard_manager = DashboardManager(self.config.client) + + # Get 
substitutions + substitutions = self.config.get_substitutions(self.prefix) + + # Step 1: Create Unity Catalog resources + console.print("\n[bold blue]Step 1: Setting up Unity Catalog[/bold blue]") + if not self._setup_unity_catalog(uc_manager): + return False + + # Step 2: Setup secrets + console.print("\n[bold blue]Step 2: Setting up secrets[/bold blue]") + if not self._setup_secrets(secrets_manager): + return False + + # Step 3: Create connections + console.print("\n[bold blue]Step 3: Creating HTTP connections[/bold blue]") + if not self._setup_connections(sql_executor, substitutions): + return False + + # Step 4: Create sample data + console.print("\n[bold blue]Step 4: Creating sample table[/bold blue]") + if not self._create_sample_data(sql_executor, substitutions): + return False + + # Step 5: Create tokenization notebook + console.print("\n[bold blue]Step 5: Creating tokenization notebook[/bold blue]") + if not self._create_tokenization_notebook(notebook_manager): + return False + + # Step 6: Verify functions before tokenization + console.print("\n[bold blue]Step 6: Verifying functions[/bold blue]") + if not self._verify_functions(sql_executor, substitutions): + console.print("[yellow]⚠ Function verification failed - continuing[/yellow]") + + # Step 7: Execute tokenization (BEFORE applying column masks!) + console.print("\n[bold blue]Step 7: Running tokenization[/bold blue]") + tokenization_success = self._execute_tokenization(notebook_manager) + if not tokenization_success: + console.print("[yellow]⚠ Tokenization failed - continuing with setup[/yellow]") + + # Step 8: Apply column masks AFTER tokenization (correct order!) + console.print("\n[bold blue]Step 8: Applying column masks to tokenized data[/bold blue]") + if tokenization_success: # Only apply masks if tokenization succeeded + functions_success = self._setup_functions(sql_executor, substitutions) + if not functions_success: + console.print("[yellow]⚠ Column masks failed - continuing without them[/yellow]") + else: + console.print("[yellow]⚠ Skipping column masks - tokenization failed[/yellow]") + + # Step 9: Create dashboard + console.print("\n[bold blue]Step 9: Creating dashboard[/bold blue]") + dashboard_url = self._create_dashboard(dashboard_manager) + + # Success summary + self._print_success_summary(dashboard_url) + return True + + except Exception as e: + console.print(f"[red]Setup failed: {e}[/red]") + return False + + def _setup_unity_catalog(self, uc_manager: UnityCatalogManager) -> bool: + """Setup Unity Catalog resources.""" + catalog_name = f"{self.prefix}_catalog" + + success = uc_manager.create_catalog(catalog_name) + success &= uc_manager.create_schema(catalog_name, "default") + + return success + + def _setup_secrets(self, secrets_manager: SecretsManager) -> bool: + """Setup secret scope and secrets.""" + skyflow_config = { + "pat_token": self.config.skyflow.pat_token, + "vault_id": self.config.skyflow.vault_id, + "table": self.config.skyflow.table + } + + return secrets_manager.setup_skyflow_secrets(skyflow_config) + + def _setup_connections(self, sql_executor: SQLExecutor, substitutions: dict) -> bool: + """Setup HTTP connections using SQL.""" + # Create connections using SQL file + success = sql_executor.execute_sql_file( + "sql/setup/create_uc_connections.sql", + substitutions + ) + + if success: + # Execute additional connection setup SQL (detokenization functions) + success &= sql_executor.execute_sql_file( + "sql/setup/setup_uc_connections_api.sql", + substitutions + ) + + return success + + def 
_create_sample_data(self, sql_executor: SQLExecutor, substitutions: dict) -> bool: + """Create sample table and data.""" + success = sql_executor.execute_sql_file( + "sql/setup/create_sample_table.sql", + substitutions + ) + + if success: + # Check table exists first without counting rows (table might be empty initially) + table_name = f"{self.prefix}_catalog.default.{self.prefix}_customer_data" + if sql_executor.verify_table_exists(table_name): + console.print(f" ✓ Created table: {table_name}") + row_count = sql_executor.get_table_row_count(table_name) + if row_count is not None and row_count > 0: + console.print(f" ✓ Table has {row_count} rows") + else: + console.print(f" ✓ Table created (empty)") + + return success + + def _setup_functions(self, sql_executor: SQLExecutor, substitutions: dict) -> bool: + """Setup detokenization functions and column masks.""" + return sql_executor.execute_sql_file( + "sql/setup/apply_column_masks.sql", + substitutions + ) + + def _create_tokenization_notebook(self, notebook_manager: NotebookManager) -> bool: + """Create the tokenization notebook.""" + try: + return notebook_manager.setup_tokenization_notebook(self.prefix) + except Exception as e: + console.print(f"✗ Notebook creation failed: {e}") + return False + + def _verify_functions(self, sql_executor: SQLExecutor, substitutions: dict) -> bool: + """Verify Unity Catalog functions exist.""" + try: + # Add 5 second delay for function creation + console.print("Verifying function creation...") + time.sleep(5) + + console.print("Verifying Unity Catalog detokenization functions...") + success = sql_executor.execute_sql_file("sql/verify/verify_functions.sql", substitutions) + if success: + console.print("✓ Unity Catalog conditional detokenization functions verified") + return success + except Exception as e: + console.print(f"✗ Function verification failed: {e}") + return False + + def _execute_tokenization(self, notebook_manager: NotebookManager) -> bool: + """Execute the tokenization notebook.""" + try: + # Get batch size from config + batch_size = getattr(self.config.skyflow, 'batch_size', 25) + return notebook_manager.execute_tokenization_notebook(self.prefix, batch_size) + except Exception as e: + console.print(f"✗ Tokenization execution failed: {e}") + return False + + def _create_dashboard(self, dashboard_manager: DashboardManager) -> Optional[str]: + """Create the customer insights dashboard.""" + return dashboard_manager.setup_customer_insights_dashboard( + self.prefix, + self.config.databricks.warehouse_id + ) + + def _print_success_summary(self, dashboard_url: Optional[str]): + """Print success summary with resources created.""" + console.print("\n" + "="*60) + console.print(Panel.fit( + f"[bold green]✓ Setup Complete: {self.prefix}[/bold green]", + style="green" + )) + + # Resources table + table = Table(title="Resources Created") + table.add_column("Resource", style="cyan") + table.add_column("Name", style="green") + + table.add_row("Unity Catalog", f"{self.prefix}_catalog") + table.add_row("Sample Table", f"{self.prefix}_customer_data") + table.add_row("Secrets Scope", "skyflow-secrets") + table.add_row("HTTP Connection", "skyflow_conn") + table.add_row("Tokenization Notebook", f"{self.prefix}_tokenize_table") + + if dashboard_url: + table.add_row("Dashboard", f"{self.prefix}_customer_insights_dashboard") + + console.print(table) + + if dashboard_url: + console.print(f"\n[bold]Dashboard URL:[/bold] {dashboard_url}") + + console.print("\n[bold]Next Steps:[/bold]") + console.print("1. 
Test role-based access by running queries as different users") + console.print("2. Explore the dashboard to see detokenization in action") + console.print("3. Use the SQL functions in your own queries and applications") + + +class DestroyCommand(BaseCommand): + """Implementation of 'destroy' command.""" + + def execute(self) -> bool: + """Execute the destroy command.""" + console.print(Panel.fit( + f"Destroying Skyflow Databricks Integration: [bold]{self.prefix}[/bold]", + style="red" + )) + + try: + self.validate_environment() + + # Initialize managers + uc_manager = UnityCatalogManager(self.config.client) + secrets_manager = SecretsManager(self.config.client) + notebook_manager = NotebookManager(self.config.client) + dashboard_manager = DashboardManager(self.config.client) + sql_executor = SQLExecutor(self.config.client, self.config.databricks.warehouse_id) + + # Track successful and failed deletions for validation + successful_deletions = [] + failed_deletions = [] + + # Step 1: Delete dashboard + console.print("\n[bold blue]Step 1: Removing dashboard[/bold blue]") + dashboard_name = f"{self.prefix}_customer_insights_dashboard" + dashboard_id = dashboard_manager.find_dashboard_by_name(dashboard_name) + if dashboard_id: + if dashboard_manager.delete_dashboard(dashboard_id): + successful_deletions.append(f"Dashboard: {dashboard_name}") + # Validate deletion + if dashboard_manager.find_dashboard_by_name(dashboard_name): + failed_deletions.append(f"Dashboard: {dashboard_name} (still exists)") + else: + failed_deletions.append(f"Dashboard: {dashboard_name}") + else: + console.print(f"✓ Dashboard '{dashboard_name}' doesn't exist") + successful_deletions.append(f"Dashboard: {dashboard_name} (didn't exist)") + + # Step 2: Delete notebook + console.print("\n[bold blue]Step 2: Removing notebook[/bold blue]") + # Use Shared folder path + notebook_path = f"/Shared/{self.prefix}_tokenize_table" + if notebook_manager.delete_notebook(notebook_path): + successful_deletions.append(f"Notebook: {notebook_path}") + # Note: Validation handled in delete_notebook method + # Note: delete_notebook already handles "doesn't exist" as success + + # Step 3: Remove column masks before dropping functions/table + console.print("\n[bold blue]Step 3: Removing column masks[/bold blue]") + catalog_name = f"{self.prefix}_catalog" + substitutions = {"PREFIX": self.prefix} + if uc_manager.catalog_exists(catalog_name): + if sql_executor.execute_sql_file("sql/destroy/remove_column_masks.sql", substitutions): + successful_deletions.append("Column masks removed") + else: + console.print("✓ Column masks removal skipped (may not exist)") + successful_deletions.append("Column masks (skipped)") + else: + console.print("✓ Column masks removal skipped (catalog doesn't exist)") + successful_deletions.append("Column masks (catalog didn't exist)") + + # Step 4: Drop functions before dropping catalog + console.print("\n[bold blue]Step 4: Dropping Unity Catalog functions[/bold blue]") + catalog_name = f"{self.prefix}_catalog" + if uc_manager.catalog_exists(catalog_name): + if sql_executor.execute_sql_file("sql/destroy/drop_functions.sql", substitutions): + successful_deletions.append("Unity Catalog functions") + # Note: Function validation skipped - functions are dropped before catalog + else: + failed_deletions.append("Unity Catalog functions") + else: + console.print(f"✓ Catalog '{catalog_name}' doesn't exist, skipping function cleanup") + successful_deletions.append("Functions (catalog didn't exist)") + + # Step 5: Drop table + 
console.print("\n[bold blue]Step 5: Dropping sample table[/bold blue]") + if uc_manager.catalog_exists(catalog_name): + if sql_executor.execute_sql_file("sql/destroy/drop_table.sql", substitutions): + successful_deletions.append("Sample table") + # Note: Table validation skipped - table is dropped before catalog + else: + failed_deletions.append("Sample table") + else: + successful_deletions.append("Sample table (catalog didn't exist)") + + # Step 6: Delete catalog + console.print("\n[bold blue]Step 6: Removing Unity Catalog[/bold blue]") + if uc_manager.drop_catalog(catalog_name): + successful_deletions.append(f"Catalog: {catalog_name}") + # Validate catalog deletion + if uc_manager.catalog_exists(catalog_name): + failed_deletions.append(f"Catalog: {catalog_name} (still exists)") + else: + failed_deletions.append(f"Catalog: {catalog_name}") + + # Step 7: Delete connection (single consolidated Skyflow connection) + console.print("\n[bold blue]Step 7: Cleaning up connections[/bold blue]") + conn_name = "skyflow_conn" + if uc_manager.drop_connection(conn_name): + successful_deletions.append(f"Connection: {conn_name}") + # Validate connection deletion + if uc_manager.connection_exists(conn_name): + failed_deletions.append(f"Connection: {conn_name} (still exists)") + # Note: If connection doesn't exist, drop_connection already handles this gracefully + + # Step 8: Delete secrets (only if no other catalogs using them) + console.print("\n[bold blue]Step 8: Cleaning up secrets[/bold blue]") + if secrets_manager.delete_secret_scope("skyflow-secrets"): + successful_deletions.append("Secret scope: skyflow-secrets") + # Validate secret scope deletion + if secrets_manager.secret_scope_exists("skyflow-secrets"): + failed_deletions.append("Secret scope: skyflow-secrets (still exists)") + else: + failed_deletions.append("Secret scope: skyflow-secrets") + + # Print comprehensive validation summary + self._print_destroy_summary(successful_deletions, failed_deletions) + + # Return success only if all deletions succeeded and were validated + return len(failed_deletions) == 0 + + except Exception as e: + console.print(f"[red]Destroy failed: {e}[/red]") + return False + + def _print_destroy_summary(self, successful: list, failed: list): + """Print a detailed summary of destroy operation results.""" + console.print("\n" + "="*60) + console.print("[bold]Destroy Summary[/bold]") + + if successful: + console.print(f"\n[bold green]Successfully deleted ({len(successful)}):[/bold green]") + for item in successful: + console.print(f" ✓ {item}") + + if failed: + console.print(f"\n[bold red]Failed to delete ({len(failed)}):[/bold red]") + for item in failed: + console.print(f" ✗ {item}") + console.print("\n[yellow]Warning: Some resources could not be deleted or verified. 
Manual cleanup may be required.[/yellow]") + console.print(Panel.fit( + f"[bold red]⚠ Cleanup completed with {len(failed)} errors[/bold red]", + style="yellow" + )) + else: + console.print(Panel.fit( + f"[bold green]✓ All resources successfully deleted and validated[/bold green]", + style="green" + )) + + +class VerifyCommand(BaseCommand): + """Implementation of 'verify' command.""" + + def execute(self) -> bool: + """Execute the verify command.""" + console.print(Panel.fit( + f"Verifying Skyflow Databricks Integration: [bold]{self.prefix}[/bold]", + style="blue" + )) + + try: + self.validate_environment() + + sql_executor = SQLExecutor(self.config.client, self.config.databricks.warehouse_id) + + # Verify table exists and has data + table_name = f"{self.prefix}_catalog.default.{self.prefix}_customer_data" + table_exists = sql_executor.verify_table_exists(table_name) + + if table_exists: + row_count = sql_executor.get_table_row_count(table_name) + console.print(f"✓ Table exists with {row_count} rows") + sql_executor.show_table_sample(table_name) + else: + console.print(f"✗ Table {table_name} does not exist") + return False + + # Verify functions exist + function_name = f"{self.prefix}_catalog.default.{self.prefix}_skyflow_conditional_detokenize" + function_exists = sql_executor.verify_function_exists(function_name) + + if function_exists: + console.print(f"✓ Function {function_name} exists") + else: + console.print(f"✗ Function {function_name} does not exist") + return False + + console.print(Panel.fit( + f"[bold green]✓ Verification Complete: {self.prefix}[/bold green]", + style="green" + )) + + return True + + except Exception as e: + console.print(f"[red]Verification failed: {e}[/red]") + return False \ No newline at end of file diff --git a/skyflow_databricks/config/__init__.py b/skyflow_databricks/config/__init__.py new file mode 100644 index 0000000..6c94bef --- /dev/null +++ b/skyflow_databricks/config/__init__.py @@ -0,0 +1 @@ +# Configuration management \ No newline at end of file diff --git a/skyflow_databricks/config/config.py b/skyflow_databricks/config/config.py new file mode 100644 index 0000000..a28e554 --- /dev/null +++ b/skyflow_databricks/config/config.py @@ -0,0 +1,109 @@ +"""Main configuration class for Databricks Skyflow integration.""" + +from typing import Dict, Optional +from pydantic import BaseModel, ValidationError +from databricks.sdk import WorkspaceClient +from config.env_loader import EnvLoader + + +class DatabricksConfig(BaseModel): + """Databricks configuration model.""" + host: str + token: str + warehouse_id: str + http_path: Optional[str] = None + + +class SkyflowConfig(BaseModel): + """Skyflow configuration model.""" + vault_url: str + vault_id: str + pat_token: str + table: str + batch_size: int = 25 # Default batch size + + +class GroupConfig(BaseModel): + """Group mapping configuration.""" + plain_text_groups: str = "auditor" + masked_groups: str = "customer_service" + redacted_groups: str = "marketing" + + +class SetupConfig: + """Main configuration manager for Databricks Skyflow setup.""" + + def __init__(self, env_file: str = ".env.local"): + self.env_loader = EnvLoader(env_file) + self._databricks_config: Optional[DatabricksConfig] = None + self._skyflow_config: Optional[SkyflowConfig] = None + self._group_config: Optional[GroupConfig] = None + self._client: Optional[WorkspaceClient] = None + + @property + def databricks(self) -> DatabricksConfig: + """Get Databricks configuration.""" + if self._databricks_config is None: + config_data = 
self.env_loader.get_databricks_config() + try: + self._databricks_config = DatabricksConfig(**config_data) + except ValidationError as e: + raise ValueError(f"Invalid Databricks configuration: {e}") + return self._databricks_config + + @property + def skyflow(self) -> SkyflowConfig: + """Get Skyflow configuration.""" + if self._skyflow_config is None: + config_data = self.env_loader.get_skyflow_config() + try: + self._skyflow_config = SkyflowConfig(**config_data) + except ValidationError as e: + raise ValueError(f"Invalid Skyflow configuration: {e}") + return self._skyflow_config + + @property + def groups(self) -> GroupConfig: + """Get group configuration.""" + if self._group_config is None: + config_data = self.env_loader.get_group_mappings() + self._group_config = GroupConfig(**config_data) + return self._group_config + + @property + def client(self) -> WorkspaceClient: + """Get authenticated Databricks client.""" + if self._client is None: + self._client = WorkspaceClient( + host=self.databricks.host, + token=self.databricks.token + ) + return self._client + + def validate(self) -> None: + """Validate all configuration is present and correct.""" + validation = self.env_loader.validate_config() + missing = [key for key, valid in validation.items() if not valid] + + if missing: + raise ValueError(f"Missing required configuration: {', '.join(missing)}") + + # Test Databricks connection + try: + self.client.current_user.me() + except Exception as e: + raise ValueError(f"Failed to authenticate with Databricks: {e}") + + print("✓ Configuration validated successfully") + + def get_substitutions(self, prefix: str) -> Dict[str, str]: + """Get variable substitutions for SQL templates.""" + return { + "PREFIX": prefix, + "SKYFLOW_VAULT_URL": self.skyflow.vault_url, + "SKYFLOW_VAULT_ID": self.skyflow.vault_id, + "SKYFLOW_TABLE": self.skyflow.table, + "PLAIN_TEXT_GROUPS": self.groups.plain_text_groups, + "MASKED_GROUPS": self.groups.masked_groups, + "REDACTED_GROUPS": self.groups.redacted_groups + } \ No newline at end of file diff --git a/skyflow_databricks/config/env_loader.py b/skyflow_databricks/config/env_loader.py new file mode 100644 index 0000000..e4a017f --- /dev/null +++ b/skyflow_databricks/config/env_loader.py @@ -0,0 +1,75 @@ +"""Environment configuration loader for Databricks Skyflow integration.""" + +import os +from pathlib import Path +from typing import Dict, Optional, Any +from dotenv import load_dotenv + + +class EnvLoader: + """Loads and processes environment variables from .env.local file.""" + + def __init__(self, env_file: str = ".env.local"): + self.env_file = env_file + self._load_env_file() + + def _load_env_file(self) -> None: + """Load environment file if it exists.""" + env_path = Path(self.env_file) + if env_path.exists(): + print(f"Loading configuration from {self.env_file}...") + load_dotenv(env_path) + else: + print(f"Warning: {self.env_file} not found - using environment variables only") + + def get_databricks_config(self) -> Dict[str, Optional[str]]: + """Extract Databricks configuration from environment.""" + hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME") + host = f"https://{hostname}" if hostname else None + + # Extract warehouse ID from HTTP path + http_path = os.getenv("DATABRICKS_HTTP_PATH") + warehouse_id = None + if http_path and "/warehouses/" in http_path: + warehouse_id = http_path.split("/warehouses/")[-1] + + return { + "host": host, + "token": os.getenv("DATABRICKS_PAT_TOKEN"), + "warehouse_id": warehouse_id, + "http_path": http_path + } + + def 
get_skyflow_config(self) -> Dict[str, Any]: + """Extract Skyflow configuration from environment.""" + return { + "vault_url": os.getenv("SKYFLOW_VAULT_URL"), + "vault_id": os.getenv("SKYFLOW_VAULT_ID"), + "pat_token": os.getenv("SKYFLOW_PAT_TOKEN"), + "table": os.getenv("SKYFLOW_TABLE"), + "table_column": os.getenv("SKYFLOW_TABLE_COLUMN", "pii_values"), + "batch_size": int(os.getenv("SKYFLOW_BATCH_SIZE", "25")) + } + + def get_group_mappings(self) -> Dict[str, str]: + """Extract group mappings for detokenization.""" + return { + "plain_text_groups": os.getenv("PLAIN_TEXT_GROUPS", "auditor"), + "masked_groups": os.getenv("MASKED_GROUPS", "customer_service"), + "redacted_groups": os.getenv("REDACTED_GROUPS", "marketing") + } + + def validate_config(self) -> Dict[str, bool]: + """Validate that required configuration is present.""" + databricks = self.get_databricks_config() + skyflow = self.get_skyflow_config() + + return { + "databricks_host": databricks["host"] is not None, + "databricks_token": databricks["token"] is not None, + "warehouse_id": databricks["warehouse_id"] is not None, + "skyflow_vault_url": skyflow["vault_url"] is not None, + "skyflow_vault_id": skyflow["vault_id"] is not None, + "skyflow_pat_token": skyflow["pat_token"] is not None, + "skyflow_table": skyflow["table"] is not None + } \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/__init__.py b/skyflow_databricks/databricks_ops/__init__.py new file mode 100644 index 0000000..e2b2c22 --- /dev/null +++ b/skyflow_databricks/databricks_ops/__init__.py @@ -0,0 +1 @@ +# Databricks SDK operations \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/client.py b/skyflow_databricks/databricks_ops/client.py new file mode 100644 index 0000000..a0263ab --- /dev/null +++ b/skyflow_databricks/databricks_ops/client.py @@ -0,0 +1,68 @@ +"""Databricks SDK client wrapper with error handling.""" + +import time +from typing import Optional, Dict, Any +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +console = Console() + + +class DatabricksClientWrapper: + """Enhanced Databricks client with retry logic and better error handling.""" + + def __init__(self, client: WorkspaceClient): + self.client = client + + def wait_for_completion(self, operation_name: str, check_func, timeout: int = 300) -> bool: + """Wait for an operation to complete with progress indication.""" + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task(f"Waiting for {operation_name}...", total=None) + + start_time = time.time() + while time.time() - start_time < timeout: + try: + if check_func(): + progress.update(task, description=f"✓ {operation_name} completed") + return True + except Exception: + pass # Continue waiting + + time.sleep(2) + + progress.update(task, description=f"✗ {operation_name} timed out") + return False + + def execute_with_retry(self, operation, max_retries: int = 3, delay: int = 2) -> Any: + """Execute operation with retry logic.""" + last_error = None + + for attempt in range(max_retries): + try: + return operation() + except DatabricksError as e: + last_error = e + if attempt < max_retries - 1: + console.print(f"Attempt {attempt + 1} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 # Exponential backoff + else: + console.print(f"All {max_retries} attempts failed") + + raise last_error + + def check_resource_exists(self, resource_type: str, check_func) -> bool: + """Check if a resource exists without throwing errors.""" + try: + check_func() + return True + except DatabricksError as e: + if "does not exist" in str(e) or "not found" in str(e).lower(): + return False + raise # Re-raise if it's not a "not found" error \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/dashboards.py b/skyflow_databricks/databricks_ops/dashboards.py new file mode 100644 index 0000000..8f6f13f --- /dev/null +++ b/skyflow_databricks/databricks_ops/dashboards.py @@ -0,0 +1,189 @@ +"""Dashboard operations - replaces bash dashboard creation.""" + +import json +from pathlib import Path +from typing import Dict, Optional, Any +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.dashboards import Dashboard +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from databricks_ops.client import DatabricksClientWrapper + +console = Console() + + +class DashboardManager: + """Manages Databricks Lakeview dashboards.""" + + def __init__(self, client: WorkspaceClient): + self.client = client + self.wrapper = DatabricksClientWrapper(client) + + def create_dashboard_from_file(self, local_path: str, dashboard_name: str, + warehouse_id: str, substitutions: Optional[Dict[str, str]] = None) -> Optional[str]: + """Create a Lakeview dashboard from local JSON file using SDK methods.""" + try: + path = Path(local_path) + if not path.exists(): + console.print(f"✗ Dashboard file not found: {local_path}") + return None + + # Read dashboard definition + with open(path, 'r') as f: + dashboard_content = f.read() + + # Apply substitutions to dashboard content + if substitutions: + for key, value in substitutions.items(): + dashboard_content = dashboard_content.replace(f"${{{key}}}", str(value)) + + # Parse JSON to validate + try: + dashboard_json = json.loads(dashboard_content) + except json.JSONDecodeError as e: + console.print(f"✗ Invalid JSON in dashboard file: {e}") + return None + + # Delete existing dashboard if it exists (using SDK) + self._delete_existing_dashboard_sdk(dashboard_name) + + # Create dashboard using SDK + dashboard = Dashboard( + display_name=dashboard_name, + warehouse_id=warehouse_id, + serialized_dashboard=json.dumps(dashboard_json), # JSON encode dashboard content + parent_path="/Shared" + ) + + def create_dashboard(): + return self.client.lakeview.create(dashboard) + + result = self.wrapper.execute_with_retry(create_dashboard) + + if result.dashboard_id: + console.print(f"✓ Created dashboard: {dashboard_name}") + host = self.client.config.host + dashboard_url = f"{host}/sql/dashboardsv3/{result.dashboard_id}" + console.print(f" Dashboard URL: {dashboard_url}") + return dashboard_url + else: + console.print(f"✗ Could not extract dashboard ID from response") + return None + + except Exception as e: + console.print(f"✗ Error creating dashboard: {e}") + return None + + def _delete_existing_dashboard_sdk(self, dashboard_name: str) -> None: + """Delete existing dashboard with the same name using SDK.""" + try: + # Find existing dashboard by name using existing SDK method + existing_id = self.find_dashboard_by_name(dashboard_name) + if existing_id: + console.print(f"Deleting existing dashboard: {dashboard_name}") + self.delete_dashboard(existing_id) + except Exception: 
+            # Don't fail if we can't delete existing - just continue
+            pass
+
+    def delete_dashboard(self, dashboard_id: str) -> bool:
+        """Delete a dashboard by ID."""
+        try:
+            def delete():
+                return self.client.lakeview.trash(dashboard_id)
+
+            self.wrapper.execute_with_retry(delete)
+            console.print(f"✓ Deleted dashboard: {dashboard_id}")
+            return True
+
+        except DatabricksError as e:
+            if "not found" in str(e).lower():
+                console.print(f"✓ Dashboard {dashboard_id} doesn't exist")
+                return True
+            console.print(f"✗ Failed to delete dashboard {dashboard_id}: {e}")
+            return False
+
+    def list_dashboards(self) -> list:
+        """List all dashboards."""
+        try:
+            dashboards = self.client.lakeview.list()
+            return [
+                {
+                    "id": d.dashboard_id,
+                    "name": d.display_name,
+                    "warehouse_id": d.warehouse_id
+                }
+                for d in dashboards
+            ]
+        except DatabricksError as e:
+            console.print(f"✗ Failed to list dashboards: {e}")
+            return []
+
+    def find_dashboard_by_name(self, name: str) -> Optional[str]:
+        """Find dashboard ID by name."""
+        dashboards = self.list_dashboards()
+        for dashboard in dashboards:
+            if dashboard["name"] == name:
+                return dashboard["id"]
+        return None
+
+    def setup_customer_insights_dashboard(self, prefix: str, warehouse_id: str) -> Optional[str]:
+        """Setup the customer insights dashboard for the specified prefix."""
+        # Get template path relative to this module
+        template_dir = Path(__file__).parent.parent / "templates"
+        dashboard_file = template_dir / "dashboards" / "customer_insights_dashboard.lvdash.json"
+        dashboard_name = f"{prefix}_customer_insights_dashboard"
+
+        # Prepare substitutions
+        substitutions = {
+            "PREFIX": prefix,
+            f"{prefix.upper()}_CATALOG": f"{prefix}_catalog",
+            f"{prefix.upper()}_CUSTOMER_DATASET": f"{prefix}_customer_data"
+        }
+
+        # Check if dashboard already exists
+        existing_id = self.find_dashboard_by_name(dashboard_name)
+        if existing_id:
+            console.print(f"✓ Dashboard '{dashboard_name}' already exists")
+            dashboard_url = f"{self.client.config.host}/sql/dashboardsv3/{existing_id}"
+            console.print(f" Dashboard URL: {dashboard_url}")
+            return dashboard_url
+
+        # Create new dashboard
+        return self.create_dashboard_from_file(
+            str(dashboard_file),
+            dashboard_name,
+            warehouse_id,
+            substitutions
+        )
+
+    def update_dashboard_warehouse(self, dashboard_id: str, warehouse_id: str) -> bool:
+        """Update the warehouse used by a dashboard."""
+        try:
+            def update():
+                return self.client.lakeview.update(
+                    dashboard_id=dashboard_id,
+                    warehouse_id=warehouse_id
+                )
+
+            self.wrapper.execute_with_retry(update)
+            console.print(f"✓ Updated dashboard warehouse to: {warehouse_id}")
+            return True
+
+        except DatabricksError as e:
+            console.print(f"✗ Failed to update dashboard warehouse: {e}")
+            return False
+
+    def publish_dashboard(self, dashboard_id: str) -> bool:
+        """Publish a dashboard."""
+        try:
+            def publish():
+                return self.client.lakeview.publish(dashboard_id)
+
+            self.wrapper.execute_with_retry(publish)
+            console.print(f"✓ Published dashboard: {dashboard_id}")
+            return True
+
+        except DatabricksError as e:
+            console.print(f"✗ Failed to publish dashboard: {e}")
+            return False
\ No newline at end of file
diff --git a/skyflow_databricks/databricks_ops/notebooks.py b/skyflow_databricks/databricks_ops/notebooks.py
new file mode 100644
index 0000000..223b2be
--- /dev/null
+++ b/skyflow_databricks/databricks_ops/notebooks.py
@@ -0,0 +1,242 @@
+"""Notebook operations - replaces bash notebook creation/execution."""
+
+import time
+from pathlib import Path
+from typing import Optional, List
+from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import Language, ObjectType, ImportFormat +from databricks.sdk.service.jobs import SubmitTask, NotebookTask, Source +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from databricks_ops.client import DatabricksClientWrapper + +console = Console() + + +class NotebookManager: + """Manages Databricks notebooks and job execution.""" + + def __init__(self, client: WorkspaceClient): + self.client = client + self.wrapper = DatabricksClientWrapper(client) + + def create_notebook_from_file(self, local_path: str, workspace_path: str) -> bool: + """Create a notebook in workspace from local file.""" + try: + path = Path(local_path) + if not path.exists(): + console.print(f"✗ Notebook file not found: {local_path}") + return False + + # Read notebook content + with open(path, 'r') as f: + if path.suffix == '.ipynb': + # Jupyter notebook + notebook_content = f.read() + language = Language.PYTHON + else: + # Assume Python script + content = f.read() + # Convert to simple notebook format for upload + notebook_content = content + language = Language.PYTHON + + def upload_notebook(): + return self.client.workspace.upload( + path=workspace_path, + content=notebook_content.encode('utf-8'), + language=language, + overwrite=True, + format=ImportFormat.JUPYTER if path.suffix == '.ipynb' else ImportFormat.SOURCE + ) + + self.wrapper.execute_with_retry(upload_notebook) + console.print(f"✓ Created notebook: {workspace_path}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to create notebook {workspace_path}: {e}") + return False + + def run_notebook_job(self, notebook_path: str, table_name: str, pii_columns: str, + batch_size: int = 25, timeout_minutes: int = 15) -> bool: + """Run a notebook as a serverless job using SDK methods.""" + try: + # Use serverless compute with multi-task format + run_name = f"Serverless_Tokenize_{table_name.replace('.', '_')}_{int(time.time())}" + + console.print(f"Running notebook: {notebook_path}") + console.print(f"Batch size: {batch_size}") + + # Create submit task using SDK classes + submit_task = SubmitTask( + task_key="tokenize_task", + notebook_task=NotebookTask( + notebook_path=notebook_path, + source=Source.WORKSPACE, + base_parameters={ + "table_name": table_name, + "pii_columns": pii_columns, + "batch_size": str(batch_size) + } + ), + timeout_seconds=1800 # 30 minutes + ) + + # Submit job using SDK method + def submit_job(): + return self.client.jobs.submit( + run_name=run_name, + tasks=[submit_task] + ) + + waiter = self.wrapper.execute_with_retry(submit_job) + run_id = waiter.run_id + console.print(f"✓ Started notebook run with ID: {run_id}") + + # Extract workspace ID for live logs URL + workspace_id = self._extract_workspace_id() + if workspace_id: + host = self.client.config.host + console.print(f"View live logs: {host}/jobs/runs/{run_id}?o={workspace_id}") + + # Wait for completion with progress + console.print("Waiting for tokenization to complete...") + return self._monitor_job_execution_sdk(run_id, timeout_minutes) + + except Exception as e: + console.print(f"✗ Failed to run notebook job: {e}") + return False + + def _extract_workspace_id(self) -> Optional[str]: + """Extract workspace ID from hostname for live logs URL.""" + try: + import re + host = self.client.config.host + # Extract from pattern: https://dbc-{workspace_id}-{suffix}.cloud.databricks.com + match 
= re.search(r'dbc-([a-f0-9]+)-', host) + return match.group(1) if match else None + except: + return None + + def _monitor_job_execution_sdk(self, run_id: int, timeout_minutes: int) -> bool: + """Monitor job execution using SDK methods with same polling behavior.""" + max_wait_seconds = timeout_minutes * 60 + wait_time = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Tokenization in progress...", total=None) + + while wait_time < max_wait_seconds: + try: + # Use SDK method to get run status + def get_run_status(): + return self.client.jobs.get_run(run_id) + + run = self.wrapper.execute_with_retry(get_run_status) + state = run.state.life_cycle_state.value if run.state and run.state.life_cycle_state else 'UNKNOWN' + + if state == "TERMINATED": + result_state = run.state.result_state.value if run.state and run.state.result_state else 'UNKNOWN' + if result_state == "SUCCESS": + progress.update(task, description="✅ Tokenization completed successfully") + console.print("✅ Tokenization completed successfully") + return True + else: + error_msg = run.state.state_message if run.state and run.state.state_message else f"Failed with result: {result_state}" + progress.update(task, description="❌ Tokenization failed") + console.print(f"❌ Tokenization failed with result: {result_state}") + console.print(f"Error: {error_msg}") + return False + + elif state in ["INTERNAL_ERROR", "FAILED", "TIMEDOUT", "CANCELED", "SKIPPED"]: + progress.update(task, description=f"❌ Tokenization run failed: {state}") + console.print(f"❌ Tokenization run failed with state: {state}") + return False + + else: + # Job still running - update progress + progress.update(task, description=f"Tokenization in progress... 
(state: {state})") + time.sleep(30) # Poll every 30 seconds + wait_time += 30 + + except Exception as e: + console.print(f"Error monitoring job: {e}") + time.sleep(30) + wait_time += 30 + + # Timeout reached + progress.update(task, description="❌ Tokenization timed out") + console.print(f"❌ Tokenization timed out after {timeout_minutes} minutes") + return False + + def delete_notebook(self, workspace_path: str) -> bool: + """Delete a notebook from workspace.""" + try: + # Check if notebook exists first to avoid retry confusion + if not self.notebook_exists(workspace_path): + console.print(f"✓ Notebook {workspace_path} doesn't exist") + return True + + def delete(): + return self.client.workspace.delete(workspace_path, recursive=True) + + self.wrapper.execute_with_retry(delete) + console.print(f"✓ Deleted notebook: {workspace_path}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to delete notebook {workspace_path}: {e}") + return False + + def list_notebooks(self, workspace_path: str) -> List[str]: + """List notebooks in a workspace directory.""" + try: + objects = self.client.workspace.list(workspace_path) + return [ + obj.path for obj in objects + if obj.object_type == ObjectType.NOTEBOOK + ] + except DatabricksError: + return [] + + def notebook_exists(self, workspace_path: str) -> bool: + """Check if a notebook exists in workspace.""" + try: + obj = self.client.workspace.get_status(workspace_path) + return obj.object_type == ObjectType.NOTEBOOK + except DatabricksError: + return False + + def setup_tokenization_notebook(self, prefix: str) -> bool: + """Setup and execute the tokenization notebook.""" + # Get template path relative to this module + template_dir = Path(__file__).parent.parent / "templates" + local_notebook_path = template_dir / "notebooks" / "notebook_tokenize_table.ipynb" + # Use Shared folder path + workspace_path = f"/Shared/{prefix}_tokenize_table" + + # Create the notebook + if not self.create_notebook_from_file(str(local_notebook_path), workspace_path): + return False + + return True + + def execute_tokenization_notebook(self, prefix: str, batch_size: int = 25) -> bool: + """Execute the tokenization notebook with parameters.""" + workspace_path = f"/Shared/{prefix}_tokenize_table" + table_name = f"{prefix}_catalog.default.{prefix}_customer_data" + pii_columns = "first_name,last_name,email,phone_number,address,date_of_birth" # PII columns to tokenize + + console.print("Tokenizing PII data in sample table...") + return self.run_notebook_job( + notebook_path=workspace_path, + table_name=table_name, + pii_columns=pii_columns, + batch_size=batch_size + ) \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/secrets.py b/skyflow_databricks/databricks_ops/secrets.py new file mode 100644 index 0000000..66e1a5e --- /dev/null +++ b/skyflow_databricks/databricks_ops/secrets.py @@ -0,0 +1,134 @@ +"""Secrets management - replaces bash secrets setup functionality.""" + +from typing import Dict, List +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import ScopeBackendType +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from databricks_ops.client import DatabricksClientWrapper + +console = Console() + + +class SecretsManager: + """Manages Databricks secrets and secret scopes.""" + + def __init__(self, client: WorkspaceClient): + self.client = client + self.wrapper = DatabricksClientWrapper(client) + + def create_secret_scope(self, scope_name: str, backend_type: 
str = "DATABRICKS") -> bool: + """Create a secret scope.""" + try: + # Check if scope already exists + existing_scopes = self.client.secrets.list_scopes() + if any(scope.name == scope_name for scope in existing_scopes): + console.print(f"✓ Secret scope '{scope_name}' already exists") + return True + + # Map backend type + backend = ScopeBackendType.DATABRICKS + if backend_type.upper() == "AZURE_KEYVAULT": + backend = ScopeBackendType.AZURE_KEYVAULT + + def create_scope(): + return self.client.secrets.create_scope( + scope=scope_name, + scope_backend_type=backend + ) + + self.wrapper.execute_with_retry(create_scope) + console.print(f"✓ Created secret scope: {scope_name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to create secret scope {scope_name}: {e}") + return False + + def put_secret(self, scope_name: str, key: str, value: str) -> bool: + """Put a secret value in the specified scope.""" + try: + def put_secret_value(): + return self.client.secrets.put_secret( + scope=scope_name, + key=key, + string_value=value + ) + + self.wrapper.execute_with_retry(put_secret_value) + console.print(f"✓ Set secret: {scope_name}/{key}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to set secret {scope_name}/{key}: {e}") + return False + + def delete_secret_scope(self, scope_name: str) -> bool: + """Delete a secret scope and all its secrets.""" + try: + # Check if scope exists + existing_scopes = self.client.secrets.list_scopes() + if not any(scope.name == scope_name for scope in existing_scopes): + console.print(f"✓ Secret scope '{scope_name}' doesn't exist") + return True + + def delete_scope(): + return self.client.secrets.delete_scope(scope_name) + + self.wrapper.execute_with_retry(delete_scope) + console.print(f"✓ Deleted secret scope: {scope_name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to delete secret scope {scope_name}: {e}") + return False + + def setup_skyflow_secrets(self, skyflow_config: Dict[str, str]) -> bool: + """Setup all Skyflow-related secrets.""" + scope_name = "skyflow-secrets" + + # Create the scope + if not self.create_secret_scope(scope_name): + return False + + # Secret mappings + secrets = { + "skyflow_pat_token": skyflow_config["pat_token"], + "skyflow_vault_id": skyflow_config["vault_id"], + "skyflow_table": skyflow_config["table"], + "skyflow_table_column": skyflow_config.get("table_column", "pii_values") # Skyflow table column name + } + + success = True + for key, value in secrets.items(): + if not self.put_secret(scope_name, key, value): + success = False + + return success + + def list_secrets_in_scope(self, scope_name: str) -> List[str]: + """List all secret keys in a scope.""" + try: + secrets = self.client.secrets.list_secrets(scope_name) + return [secret.key for secret in secrets] + except DatabricksError: + return [] + + def verify_secrets(self, scope_name: str, required_keys: List[str]) -> bool: + """Verify that all required secrets exist in the scope.""" + existing_keys = self.list_secrets_in_scope(scope_name) + missing_keys = [key for key in required_keys if key not in existing_keys] + + if missing_keys: + console.print(f"✗ Missing secrets in {scope_name}: {', '.join(missing_keys)}") + return False + + console.print(f"✓ All required secrets exist in {scope_name}") + return True + + def secret_scope_exists(self, scope_name: str) -> bool: + """Check if a secret scope exists.""" + return self.wrapper.check_resource_exists( + "secret scope", + lambda: 
self.client.secrets.list_secrets(scope_name) + ) \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/sql.py b/skyflow_databricks/databricks_ops/sql.py new file mode 100644 index 0000000..a68327e --- /dev/null +++ b/skyflow_databricks/databricks_ops/sql.py @@ -0,0 +1,173 @@ +"""SQL execution - replaces bash execute_sql functionality.""" + +import time +from pathlib import Path +from typing import Dict, Optional, List, Any +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.sql import StatementState, StatementResponse +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from rich.table import Table +from databricks_ops.client import DatabricksClientWrapper + +console = Console() + + +class SQLExecutor: + """Executes SQL files and statements against Databricks.""" + + def __init__(self, client: WorkspaceClient, warehouse_id: str): + self.client = client + self.warehouse_id = warehouse_id + self.wrapper = DatabricksClientWrapper(client) + + def apply_substitutions(self, sql: str, substitutions: Dict[str, str]) -> str: + """Apply variable substitutions to SQL content.""" + if not substitutions: + return sql + + for key, value in substitutions.items(): + sql = sql.replace(f"${{{key}}}", str(value)) + + return sql + + def execute_statement(self, sql: str, timeout: int = 300) -> Optional[StatementResponse]: + """Execute a single SQL statement.""" + try: + def execute(): + return self.client.statement_execution.execute_statement( + warehouse_id=self.warehouse_id, + statement=sql, + wait_timeout="30s" + ) + + response = self.wrapper.execute_with_retry(execute) + + # Wait for completion if needed + if response.status.state in [StatementState.PENDING, StatementState.RUNNING]: + def check_completion(): + result = self.client.statement_execution.get_statement(response.statement_id) + return result.status.state in [StatementState.SUCCEEDED, StatementState.FAILED, StatementState.CANCELED] + + if self.wrapper.wait_for_completion("SQL execution", check_completion, timeout): + response = self.client.statement_execution.get_statement(response.statement_id) + + if response.status.state == StatementState.SUCCEEDED: + return response + else: + error_msg = response.status.error.message if response.status.error else "Unknown error" + console.print(f"✗ SQL execution failed: {error_msg}") + return None + + except DatabricksError as e: + console.print(f"✗ SQL execution error: {e}") + return None + + def execute_sql_file(self, file_path: str, substitutions: Optional[Dict[str, str]] = None) -> bool: + """Execute SQL from a file with variable substitutions.""" + # If path is relative, look in templates directory + if not Path(file_path).is_absolute(): + template_dir = Path(__file__).parent.parent / "templates" + path = template_dir / file_path + else: + path = Path(file_path) + + if not path.exists(): + console.print(f"✗ SQL file not found: {path}") + return False + + console.print(f"Executing SQL file: {path.name}") + + try: + with open(path, 'r') as f: + sql_content = f.read() + + # Apply substitutions + if substitutions: + sql_content = self.apply_substitutions(sql_content, substitutions) + + # Split into individual statements (simple approach) + statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()] + + success = True + for i, statement in enumerate(statements): + console.print(f" Executing statement {i+1}/{len(statements)}") + result = self.execute_statement(statement) + + if result is None: + success = False + 
console.print(f"✗ Failed to execute statement {i+1}") + break + else: + console.print(f" ✓ Statement {i+1} completed") + + if success: + console.print(f"✓ Successfully executed {path.name}") + + return success + + except Exception as e: + console.print(f"✗ Error reading/executing {file_path}: {e}") + return False + + def execute_query_with_results(self, sql: str, max_rows: int = 100) -> Optional[List[Dict[str, Any]]]: + """Execute a query and return results.""" + response = self.execute_statement(sql) + + if response and response.result and response.result.data_array: + # Convert to list of dictionaries + columns = [col.name for col in response.manifest.schema.columns] if response.manifest else [] + results = [] + + for row in response.result.data_array[:max_rows]: + row_dict = {columns[i]: row[i] for i in range(len(columns))} if columns else {} + results.append(row_dict) + + return results + + return None + + def verify_table_exists(self, table_name: str) -> bool: + """Check if a table exists.""" + sql = f"DESCRIBE TABLE {table_name}" + result = self.execute_statement(sql) + return result is not None + + def verify_function_exists(self, function_name: str) -> bool: + """Check if a function exists.""" + sql = f"DESCRIBE FUNCTION {function_name}" + result = self.execute_statement(sql) + return result is not None + + def get_table_row_count(self, table_name: str) -> Optional[int]: + """Get row count for a table.""" + sql = f"SELECT COUNT(*) as count FROM {table_name}" + results = self.execute_query_with_results(sql) + + if results and len(results) > 0: + count_value = results[0].get('count', 0) + # Convert to int if it's a string + return int(count_value) if count_value is not None else 0 + + return None + + def show_table_sample(self, table_name: str, limit: int = 5) -> None: + """Display a sample of table data.""" + sql = f"SELECT * FROM {table_name} LIMIT {limit}" + results = self.execute_query_with_results(sql, max_rows=limit) + + if results: + table = Table(title=f"Sample data from {table_name}") + + # Add columns + if results: + for column in results[0].keys(): + table.add_column(column) + + # Add rows + for row in results: + table.add_row(*[str(value) for value in row.values()]) + + console.print(table) + else: + console.print(f"No data found in {table_name}") \ No newline at end of file diff --git a/skyflow_databricks/databricks_ops/unity_catalog.py b/skyflow_databricks/databricks_ops/unity_catalog.py new file mode 100644 index 0000000..72337f4 --- /dev/null +++ b/skyflow_databricks/databricks_ops/unity_catalog.py @@ -0,0 +1,174 @@ +"""Unity Catalog operations - replaces bash UC connection setup.""" + +from typing import Dict, Optional, List +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.catalog import ConnectionType +from databricks.sdk.errors import DatabricksError +from rich.console import Console +from databricks_ops.client import DatabricksClientWrapper + +console = Console() + + +class UnityCatalogManager: + """Manages Unity Catalog resources for Skyflow integration.""" + + def __init__(self, client: WorkspaceClient): + self.client = client + self.wrapper = DatabricksClientWrapper(client) + + def create_http_connection(self, name: str, host: str, base_path: str, + secret_scope: str, secret_key: str) -> bool: + """Create Unity Catalog HTTP connection.""" + try: + # Check if connection already exists + if self.wrapper.check_resource_exists( + "connection", + lambda: self.client.connections.get(name) + ): + console.print(f"✓ Connection '{name}' already 
exists") + return True + + # Create the connection + def create_conn(): + return self.client.connections.create( + name=name, + connection_type=ConnectionType.HTTP, + options={ + "host": host, + "port": "443", + "base_path": base_path + }, + properties={ + "bearer_token": f"secret('{secret_scope}', '{secret_key}')" + } + ) + + self.wrapper.execute_with_retry(create_conn) + console.print(f"✓ Created HTTP connection: {name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to create connection {name}: {e}") + return False + + def create_catalog(self, name: str, comment: Optional[str] = None) -> bool: + """Create Unity Catalog catalog.""" + try: + if self.wrapper.check_resource_exists( + "catalog", + lambda: self.client.catalogs.get(name) + ): + console.print(f"✓ Catalog '{name}' already exists") + return True + + def create_cat(): + return self.client.catalogs.create( + name=name, + comment=comment or f"Skyflow integration catalog - {name}" + ) + + self.wrapper.execute_with_retry(create_cat) + console.print(f"✓ Created catalog: {name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to create catalog {name}: {e}") + return False + + def create_schema(self, catalog_name: str, schema_name: str = "default") -> bool: + """Create schema in Unity Catalog.""" + full_name = f"{catalog_name}.{schema_name}" + + try: + if self.wrapper.check_resource_exists( + "schema", + lambda: self.client.schemas.get(full_name) + ): + console.print(f"✓ Schema '{full_name}' already exists") + return True + + def create_sch(): + return self.client.schemas.create( + name=schema_name, + catalog_name=catalog_name + ) + + self.wrapper.execute_with_retry(create_sch) + console.print(f"✓ Created schema: {full_name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to create schema {full_name}: {e}") + return False + + def drop_catalog(self, name: str, force: bool = True) -> bool: + """Drop Unity Catalog catalog and all contents.""" + try: + if not self.wrapper.check_resource_exists( + "catalog", + lambda: self.client.catalogs.get(name) + ): + console.print(f"✓ Catalog '{name}' doesn't exist") + return True + + def drop_cat(): + return self.client.catalogs.delete(name, force=force) + + self.wrapper.execute_with_retry(drop_cat) + console.print(f"✓ Dropped catalog: {name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to drop catalog {name}: {e}") + return False + + def drop_connection(self, name: str) -> bool: + """Drop Unity Catalog HTTP connection.""" + try: + if not self.wrapper.check_resource_exists( + "connection", + lambda: self.client.connections.get(name) + ): + console.print(f"✓ Connection '{name}' doesn't exist") + return True + + def drop_conn(): + return self.client.connections.delete(name) + + self.wrapper.execute_with_retry(drop_conn) + console.print(f"✓ Dropped connection: {name}") + return True + + except DatabricksError as e: + console.print(f"✗ Failed to drop connection {name}: {e}") + return False + + def setup_skyflow_connections(self, vault_url: str, vault_id: str) -> bool: + """Setup both Skyflow HTTP connections.""" + success = True + + # Main Skyflow connection + success &= self.create_http_connection( + name="skyflow_conn", + host=vault_url.replace("https://", "").replace("http://", ""), + base_path="/v1/vaults", + secret_scope="skyflow-secrets", + secret_key="skyflow_pat_token" + ) + + return success + + def catalog_exists(self, name: str) -> bool: + """Check if a catalog exists.""" + return 
self.wrapper.check_resource_exists( + "catalog", + lambda: self.client.catalogs.get(name) + ) + + def connection_exists(self, name: str) -> bool: + """Check if a connection exists.""" + return self.wrapper.check_resource_exists( + "connection", + lambda: self.client.connections.get(name) + ) \ No newline at end of file diff --git a/dashboards/customer_insights_dashboard.lvdash.json b/skyflow_databricks/templates/dashboards/customer_insights_dashboard.lvdash.json similarity index 100% rename from dashboards/customer_insights_dashboard.lvdash.json rename to skyflow_databricks/templates/dashboards/customer_insights_dashboard.lvdash.json diff --git a/skyflow_databricks/templates/notebooks/notebook_tokenize_table.ipynb b/skyflow_databricks/templates/notebooks/notebook_tokenize_table.ipynb new file mode 100644 index 0000000..ff17390 --- /dev/null +++ b/skyflow_databricks/templates/notebooks/notebook_tokenize_table.ipynb @@ -0,0 +1,19 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "cell-0", + "metadata": {}, + "outputs": [], + "source": "# Unity Catalog-aware serverless tokenization notebook \n# Uses dbutils.secrets.get() + UC HTTP connections for serverless compatibility\nimport json\nimport os\nfrom pyspark.sql import SparkSession\nfrom pyspark.dbutils import DBUtils\n\n# Initialize Spark session optimized for serverless compute\nspark = SparkSession.builder \\\n .appName(\"SkyflowTokenization\") \\\n .config(\"spark.databricks.cluster.profile\", \"serverless\") \\\n .config(\"spark.databricks.delta.autoCompact.enabled\", \"true\") \\\n .config(\"spark.sql.adaptive.enabled\", \"true\") \\\n .config(\"spark.sql.adaptive.coalescePartitions.enabled\", \"true\") \\\n .getOrCreate()\n \ndbutils = DBUtils(spark)\n\nprint(f\"✓ Running on Databricks serverless compute\")\nprint(f\"✓ Spark version: {spark.version}\")\n\n# Performance configuration\nMAX_MERGE_BATCH_SIZE = 10000 # Maximum records per MERGE operation\nCOLLECT_BATCH_SIZE = 1000 # Maximum records to collect() from Spark at once\n\n# Define widgets to receive input parameters\ndbutils.widgets.text(\"table_name\", \"\")\ndbutils.widgets.text(\"pii_columns\", \"\")\ndbutils.widgets.text(\"batch_size\", \"\") # Skyflow API batch size\n\n# Read widget values\ntable_name = dbutils.widgets.get(\"table_name\")\npii_columns = dbutils.widgets.get(\"pii_columns\").split(\",\")\nSKYFLOW_BATCH_SIZE = int(dbutils.widgets.get(\"batch_size\"))\n\nif not table_name or not pii_columns:\n raise ValueError(\"Both 'table_name' and 'pii_columns' must be provided.\")\n\nprint(f\"Tokenizing table: {table_name}\")\nprint(f\"PII columns: {', '.join(pii_columns)}\")\nprint(f\"Skyflow API batch size: {SKYFLOW_BATCH_SIZE}\")\nprint(f\"MERGE batch size limit: {MAX_MERGE_BATCH_SIZE:,} records\")\nprint(f\"Collect batch size: {COLLECT_BATCH_SIZE:,} records\")\n\n# Extract catalog and schema from table name if fully qualified\nif '.' 
in table_name:\n parts = table_name.split('.')\n if len(parts) == 3: # catalog.schema.table\n catalog_name = parts[0]\n schema_name = parts[1]\n table_name_only = parts[2]\n \n # Set the catalog and schema context for this session\n print(f\"Setting catalog context to: {catalog_name}\")\n spark.sql(f\"USE CATALOG {catalog_name}\")\n spark.sql(f\"USE SCHEMA {schema_name}\")\n \n # Use the simple table name for queries since context is set\n table_name = table_name_only\n print(f\"✓ Catalog context set, using table name: {table_name}\")\n\nprint(\"✓ Using dbutils.secrets.get() + UC HTTP connections for serverless compatibility\")\n\ndef tokenize_column_values(column_name, values):\n \"\"\"\n Tokenize a list of PII values using Unity Catalog HTTP connection.\n Uses dbutils.secrets.get() and http_request() for serverless compatibility.\n Returns list of tokens in same order as input values.\n \"\"\"\n if not values:\n return []\n \n # Get secrets using dbutils (works in serverless)\n table_column = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table_column\")\n skyflow_table = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table\")\n \n # Create records for each value\n skyflow_records = [{\n \"fields\": {table_column: str(value)}\n } for value in values if value is not None]\n\n # Create Skyflow tokenization payload\n payload = {\n \"records\": skyflow_records,\n \"tokenization\": True\n }\n\n print(f\" Tokenizing {len(skyflow_records)} values for {column_name}\")\n \n # Use Unity Catalog HTTP connection via SQL http_request function\n # Connection base_path is /v1/vaults/{vault_id}, so we add /{table_name}\n json_payload = json.dumps(payload).replace(\"'\", \"''\")\n tokenize_path = f\"/{skyflow_table}\"\n \n # Execute tokenization via UC connection\n result_df = spark.sql(f\"\"\"\n SELECT http_request(\n conn => 'skyflow_conn',\n method => 'POST',\n path => '{tokenize_path}',\n headers => map(\n 'Content-Type', 'application/json',\n 'Accept', 'application/json'\n ),\n json => '{json_payload}'\n ) as full_response\n \"\"\")\n \n # Parse response\n full_response = result_df.collect()[0]['full_response']\n result = json.loads(full_response.text)\n \n # Fail fast if API response indicates error\n if \"error\" in result:\n raise RuntimeError(f\"Skyflow API error: {result['error']}\")\n \n if \"records\" not in result:\n raise RuntimeError(f\"Invalid Skyflow API response - missing 'records': {result}\")\n \n # Extract tokens in order\n tokens = []\n for i, record in enumerate(result.get(\"records\", [])):\n if \"tokens\" in record and table_column in record[\"tokens\"]:\n token = record[\"tokens\"][table_column]\n tokens.append(token)\n else:\n raise RuntimeError(f\"Tokenization failed for value {i+1} in {column_name}. 
Record: {record}\")\n \n successful_tokens = len([t for i, t in enumerate(tokens) if t and str(t) != str(values[i])])\n print(f\" Successfully tokenized {successful_tokens}/{len(values)} values\")\n \n return tokens\n\ndef perform_chunked_merge(table_name, column, update_data):\n \"\"\"\n Perform MERGE operations in chunks to avoid memory/timeout issues.\n Returns total number of rows updated.\n \"\"\"\n if not update_data:\n return 0\n \n total_updated = 0\n chunk_size = MAX_MERGE_BATCH_SIZE\n total_chunks = (len(update_data) + chunk_size - 1) // chunk_size\n \n print(f\" Splitting {len(update_data):,} updates into {total_chunks} MERGE operations (max {chunk_size:,} per chunk)\")\n \n for chunk_idx in range(0, len(update_data), chunk_size):\n chunk_data = update_data[chunk_idx:chunk_idx + chunk_size]\n chunk_num = (chunk_idx // chunk_size) + 1\n \n # Create temporary view for this chunk\n temp_df = spark.createDataFrame(chunk_data, [\"customer_id\", f\"new_{column}\"])\n temp_view_name = f\"temp_{column}_chunk_{chunk_num}_{hash(column) % 1000}\"\n temp_df.createOrReplaceTempView(temp_view_name)\n \n # Perform MERGE operation for this chunk\n merge_sql = f\"\"\"\n MERGE INTO `{table_name}` AS target\n USING {temp_view_name} AS source\n ON target.customer_id = source.customer_id\n WHEN MATCHED THEN \n UPDATE SET `{column}` = source.new_{column}\n \"\"\"\n \n spark.sql(merge_sql)\n chunk_updated = len(chunk_data)\n total_updated += chunk_updated\n \n print(f\" Chunk {chunk_num}/{total_chunks}: Updated {chunk_updated:,} rows\")\n \n # Clean up temp view\n spark.catalog.dropTempView(temp_view_name)\n \n return total_updated\n\n# Process each column individually (streaming approach)\nprint(\"Starting column-by-column tokenization with streaming chunked processing...\")\n\nfor column in pii_columns:\n print(f\"\\nProcessing column: {column}\")\n \n # Get total count first for progress tracking\n total_count = spark.sql(f\"\"\"\n SELECT COUNT(*) as count \n FROM `{table_name}` \n WHERE `{column}` IS NOT NULL\n \"\"\").collect()[0]['count']\n \n if total_count == 0:\n print(f\" No data found in column {column}\")\n continue\n \n print(f\" Found {total_count:,} total values to tokenize\")\n \n # Process in streaming chunks to avoid memory issues\n all_update_data = [] # Collect all updates before final MERGE\n processed_count = 0\n \n for offset in range(0, total_count, COLLECT_BATCH_SIZE):\n chunk_size = min(COLLECT_BATCH_SIZE, total_count - offset)\n print(f\" Processing chunk {offset//COLLECT_BATCH_SIZE + 1} ({chunk_size:,} records, offset {offset:,})...\")\n \n # Get chunk of data from Spark\n chunk_df = spark.sql(f\"\"\"\n SELECT customer_id, `{column}` \n FROM `{table_name}` \n WHERE `{column}` IS NOT NULL \n ORDER BY customer_id\n LIMIT {chunk_size} OFFSET {offset}\n \"\"\")\n \n chunk_rows = chunk_df.collect()\n if not chunk_rows:\n continue\n \n # Extract customer IDs and values for this chunk\n chunk_customer_ids = [row['customer_id'] for row in chunk_rows]\n chunk_column_values = [row[column] for row in chunk_rows]\n \n # Tokenize this chunk's values in Skyflow API batches\n chunk_tokens = []\n if len(chunk_column_values) <= SKYFLOW_BATCH_SIZE: # Single API batch\n chunk_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}\", chunk_column_values)\n else: # Multiple API batches within this chunk\n for i in range(0, len(chunk_column_values), SKYFLOW_BATCH_SIZE):\n api_batch_values = chunk_column_values[i:i + SKYFLOW_BATCH_SIZE]\n api_batch_tokens = 
tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}_api_{i//SKYFLOW_BATCH_SIZE + 1}\", api_batch_values)\n chunk_tokens.extend(api_batch_tokens)\n \n # Verify token count matches input count (fail fast)\n if len(chunk_tokens) != len(chunk_customer_ids):\n raise RuntimeError(f\"Token count mismatch: got {len(chunk_tokens)} tokens for {len(chunk_customer_ids)} input values\")\n \n # Collect update data for rows that changed in this chunk\n chunk_original_map = {chunk_customer_ids[i]: chunk_column_values[i] for i in range(len(chunk_customer_ids))}\n \n for i, (customer_id, token) in enumerate(zip(chunk_customer_ids, chunk_tokens)):\n if token and str(token) != str(chunk_original_map[customer_id]):\n all_update_data.append((customer_id, token))\n \n processed_count += len(chunk_rows)\n print(f\" Processed {processed_count:,}/{total_count:,} records ({(processed_count/total_count)*100:.1f}%)\")\n \n # Perform final chunked MERGE operations for all collected updates\n if all_update_data:\n print(f\" Performing final chunked MERGE of {len(all_update_data):,} changed rows...\")\n total_updated = perform_chunked_merge(table_name, column, all_update_data)\n print(f\" ✓ Successfully updated {total_updated:,} rows in column {column}\")\n else:\n print(f\" No updates needed - all tokens match original values\")\n\nprint(\"\\nOptimized streaming tokenization completed!\")\n\n# Verify results\nprint(\"\\nFinal verification:\")\nfor column in pii_columns:\n sample_df = spark.sql(f\"\"\"\n SELECT `{column}`, COUNT(*) as count \n FROM `{table_name}` \n GROUP BY `{column}` \n LIMIT 3\n \"\"\")\n print(f\"\\nSample values in {column}:\")\n sample_df.show(truncate=False)\n\ntotal_rows = spark.sql(f\"SELECT COUNT(*) as count FROM `{table_name}`\").collect()[0][\"count\"]\nprint(f\"\\nTable size: {total_rows:,} total rows\")\n\nprint(f\"Optimized streaming tokenization completed for {len(pii_columns)} columns\")" + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/sql/destroy/drop_functions.sql b/skyflow_databricks/templates/sql/destroy/drop_functions.sql similarity index 100% rename from sql/destroy/drop_functions.sql rename to skyflow_databricks/templates/sql/destroy/drop_functions.sql diff --git a/skyflow_databricks/templates/sql/destroy/drop_table.sql b/skyflow_databricks/templates/sql/destroy/drop_table.sql new file mode 100644 index 0000000..0236f01 --- /dev/null +++ b/skyflow_databricks/templates/sql/destroy/drop_table.sql @@ -0,0 +1,5 @@ +-- Drop sample customer data table during cleanup +-- Set catalog context +USE CATALOG ${PREFIX}_catalog; + +DROP TABLE IF EXISTS ${PREFIX}_catalog.default.${PREFIX}_customer_data; \ No newline at end of file diff --git a/skyflow_databricks/templates/sql/destroy/remove_column_masks.sql b/skyflow_databricks/templates/sql/destroy/remove_column_masks.sql new file mode 100644 index 0000000..970925e --- /dev/null +++ b/skyflow_databricks/templates/sql/destroy/remove_column_masks.sql @@ -0,0 +1,12 @@ +-- Remove column masks before dropping table during cleanup +-- Remove masks from all tokenized PII columns during cleanup + +-- Set catalog context +USE CATALOG ${PREFIX}_catalog; + +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN first_name DROP MASK; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN last_name DROP MASK; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN 
email DROP MASK; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN phone_number DROP MASK; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN address DROP MASK; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN date_of_birth DROP MASK; \ No newline at end of file diff --git a/skyflow_databricks/templates/sql/setup/apply_column_masks.sql b/skyflow_databricks/templates/sql/setup/apply_column_masks.sql new file mode 100644 index 0000000..f206eb9 --- /dev/null +++ b/skyflow_databricks/templates/sql/setup/apply_column_masks.sql @@ -0,0 +1,14 @@ +-- Apply column masks to PII columns using Unity Catalog SQL-only functions +-- Pure SQL performance with UC connections - zero Python UDF overhead +-- Role-based access: auditors see detokenized, customer_service sees masked, others see tokens + +-- Set catalog context +USE CATALOG ${PREFIX}_catalog; + +-- Apply masks to all tokenized PII columns +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN first_name SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN last_name SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN email SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN phone_number SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN address SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; +ALTER TABLE ${PREFIX}_catalog.default.${PREFIX}_customer_data ALTER COLUMN date_of_birth SET MASK ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; \ No newline at end of file diff --git a/sql/setup/create_sample_table.sql b/skyflow_databricks/templates/sql/setup/create_sample_table.sql similarity index 92% rename from sql/setup/create_sample_table.sql rename to skyflow_databricks/templates/sql/setup/create_sample_table.sql index f916688..e0da997 100644 --- a/sql/setup/create_sample_table.sql +++ b/skyflow_databricks/templates/sql/setup/create_sample_table.sql @@ -1,4 +1,7 @@ -CREATE TABLE IF NOT EXISTS ${PREFIX}_customer_data ( +-- Set catalog context +USE CATALOG ${PREFIX}_catalog; + +CREATE TABLE IF NOT EXISTS ${PREFIX}_catalog.default.${PREFIX}_customer_data ( customer_id STRING NOT NULL, first_name STRING, last_name STRING, @@ -18,7 +21,7 @@ CREATE TABLE IF NOT EXISTS ${PREFIX}_customer_data ( updated_at TIMESTAMP ); -INSERT INTO ${PREFIX}_customer_data ( +INSERT INTO ${PREFIX}_catalog.default.${PREFIX}_customer_data ( customer_id, first_name, last_name, @@ -39,7 +42,7 @@ INSERT INTO ${PREFIX}_customer_data ( ) WITH numbered_rows AS ( SELECT - posexplode(array_repeat(1, 50)) AS (id, _) + posexplode(array_repeat(1, 20)) AS (id, _) ), base_data AS ( SELECT diff --git a/sql/setup/create_uc_connections.sql b/skyflow_databricks/templates/sql/setup/create_uc_connections.sql similarity index 63% rename from sql/setup/create_uc_connections.sql rename to skyflow_databricks/templates/sql/setup/create_uc_connections.sql index 6ea591b..bac045d 100644 --- a/sql/setup/create_uc_connections.sql +++ b/skyflow_databricks/templates/sql/setup/create_uc_connections.sql @@ -1,11 +1,12 @@ -- Unity Catalog HTTP connections for Skyflow API integration -- These must be created 
without catalog context (global metastore resources) --- Single consolidated Skyflow connection for both tokenization and detokenization +-- Single Skyflow connection with base_path ending in vault_id +-- Tokenization adds /{table_name}, detokenization adds /detokenize CREATE CONNECTION IF NOT EXISTS skyflow_conn TYPE HTTP OPTIONS ( host '${SKYFLOW_VAULT_URL}', port 443, - base_path '/v1/vaults', + base_path '/v1/vaults/${SKYFLOW_VAULT_ID}', bearer_token secret('skyflow-secrets', 'skyflow_pat_token') ); \ No newline at end of file diff --git a/sql/setup/setup_uc_connections_api.sql b/skyflow_databricks/templates/sql/setup/setup_uc_connections_api.sql similarity index 54% rename from sql/setup/setup_uc_connections_api.sql rename to skyflow_databricks/templates/sql/setup/setup_uc_connections_api.sql index f24e161..dcf123a 100644 --- a/sql/setup/setup_uc_connections_api.sql +++ b/skyflow_databricks/templates/sql/setup/setup_uc_connections_api.sql @@ -1,35 +1,17 @@ --- Unity Catalog Connections Setup via REST API --- Complete pure SQL implementation matching Python UDF functionality --- This provides zero Python overhead with native Spark SQL performance +-- Unity Catalog Connections Setup via SQL Functions +-- Complete pure SQL implementation with zero Python overhead +-- Uses single skyflow_conn connection with dynamic paths --- Connections are created via REST API in setup.sh: --- --- Tokenization connection: --- { --- "name": "skyflow_tokenize_conn", --- "connection_type": "HTTP", --- "options": { --- "host": "${SKYFLOW_VAULT_URL}", --- "port": "443", --- "base_path": "/v1/vaults/${SKYFLOW_VAULT_ID}/${SKYFLOW_TABLE}", --- "bearer_token": "{{secrets.skyflow-secrets.skyflow_pat_token}}" --- } --- } --- --- Detokenization connection: --- { --- "name": "skyflow_detokenize_conn", --- "connection_type": "HTTP", --- "options": { --- "host": "${SKYFLOW_VAULT_URL}", --- "port": "443", --- "base_path": "/v1/vaults/${SKYFLOW_VAULT_ID}", --- "bearer_token": "{{secrets.skyflow-secrets.skyflow_pat_token}}" --- } --- } +-- Set catalog context for functions +USE CATALOG ${PREFIX}_catalog; + +-- Connection is created via SQL in create_uc_connections.sql: +-- Single connection with base_path '/v1/vaults/${SKYFLOW_VAULT_ID}' +-- Tokenization uses path '/${SKYFLOW_TABLE}' +-- Detokenization uses path '/detokenize' -- Core detokenization function with configurable redaction level -CREATE OR REPLACE FUNCTION ${PREFIX}_skyflow_uc_detokenize(token STRING, redaction_level STRING) +CREATE OR REPLACE FUNCTION ${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize(token STRING, redaction_level STRING) RETURNS STRING LANGUAGE SQL DETERMINISTIC @@ -45,7 +27,7 @@ RETURN http_request( conn => 'skyflow_conn', method => 'POST', - path => '${SKYFLOW_VAULT_ID}/detokenize', + path => '/detokenize', headers => map( 'Content-Type', 'application/json', 'Accept', 'application/json' @@ -69,7 +51,7 @@ RETURN -- Multi-level conditional detokenization function with role-based redaction -- Supports PLAIN_TEXT, MASKED, and token-only based on user group membership -CREATE OR REPLACE FUNCTION ${PREFIX}_skyflow_conditional_detokenize(token STRING) +CREATE OR REPLACE FUNCTION ${PREFIX}_catalog.default.${PREFIX}_skyflow_conditional_detokenize(token STRING) RETURNS STRING LANGUAGE SQL DETERMINISTIC @@ -77,17 +59,17 @@ RETURN CASE -- Auditors get plain text (full detokenization) WHEN is_account_group_member('auditor') OR is_member('auditor') THEN - ${PREFIX}_skyflow_uc_detokenize(token, 'PLAIN_TEXT') + 
${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize(token, 'PLAIN_TEXT') -- Customer service gets masked data (partial redaction) WHEN is_account_group_member('customer_service') OR is_member('customer_service') THEN - ${PREFIX}_skyflow_uc_detokenize(token, 'MASKED') + ${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize(token, 'MASKED') -- Marketing and all other users get tokens without API overhead ELSE token END; -- Convenience function for column masks (uses conditional logic) -CREATE OR REPLACE FUNCTION ${PREFIX}_skyflow_mask_detokenize(token STRING) +CREATE OR REPLACE FUNCTION ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize(token STRING) RETURNS STRING LANGUAGE SQL DETERMINISTIC -RETURN ${PREFIX}_skyflow_conditional_detokenize(token); \ No newline at end of file +RETURN ${PREFIX}_catalog.default.${PREFIX}_skyflow_conditional_detokenize(token); \ No newline at end of file diff --git a/skyflow_databricks/templates/sql/verify/verify_functions.sql b/skyflow_databricks/templates/sql/verify/verify_functions.sql new file mode 100644 index 0000000..f40d192 --- /dev/null +++ b/skyflow_databricks/templates/sql/verify/verify_functions.sql @@ -0,0 +1,6 @@ +-- Verify Unity Catalog detokenization functions exist +-- Set catalog context first +USE CATALOG ${PREFIX}_catalog; + +DESCRIBE FUNCTION ${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize; +DESCRIBE FUNCTION ${PREFIX}_catalog.default.${PREFIX}_skyflow_mask_detokenize; \ No newline at end of file diff --git a/skyflow_databricks/utils/__init__.py b/skyflow_databricks/utils/__init__.py new file mode 100644 index 0000000..1f5313d --- /dev/null +++ b/skyflow_databricks/utils/__init__.py @@ -0,0 +1 @@ +# Utility functions \ No newline at end of file diff --git a/skyflow_databricks/utils/logging.py b/skyflow_databricks/utils/logging.py new file mode 100644 index 0000000..764bb32 --- /dev/null +++ b/skyflow_databricks/utils/logging.py @@ -0,0 +1,25 @@ +"""Logging configuration for the setup process.""" + +import logging +import sys +from rich.logging import RichHandler +from rich.console import Console + +console = Console() + + +def setup_logging(level: str = "INFO") -> logging.Logger: + """Setup logging with Rich handler for beautiful output.""" + + # Convert string level to logging constant + numeric_level = getattr(logging, level.upper(), logging.INFO) + + # Configure logging + logging.basicConfig( + level=numeric_level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(console=console, rich_tracebacks=True)] + ) + + return logging.getLogger("skyflow_setup") \ No newline at end of file diff --git a/skyflow_databricks/utils/validation.py b/skyflow_databricks/utils/validation.py new file mode 100644 index 0000000..dffede4 --- /dev/null +++ b/skyflow_databricks/utils/validation.py @@ -0,0 +1,64 @@ +"""Input validation utilities.""" + +import re +from typing import List, Tuple, Optional + + +def validate_prefix(prefix: str) -> Tuple[bool, Optional[str]]: + """Validate prefix name meets Databricks naming requirements.""" + if not prefix: + return False, "Prefix cannot be empty" + + if not re.match(r'^[a-zA-Z][a-zA-Z0-9_]*$', prefix): + return False, "Prefix must start with a letter and contain only letters, numbers, and underscores" + + if len(prefix) > 50: + return False, "Prefix cannot be longer than 50 characters" + + # Reserved keywords + reserved = ['system', 'information_schema', 'default', 'main', 'hive_metastore'] + if prefix.lower() in reserved: + return False, f"Prefix '{prefix}' is reserved and 
cannot be used" + + return True, None + + +def validate_warehouse_id(warehouse_id: str) -> Tuple[bool, Optional[str]]: + """Validate warehouse ID format.""" + if not warehouse_id: + return False, "Warehouse ID cannot be empty" + + # Basic format validation - Databricks warehouse IDs are typically UUIDs or similar + if not re.match(r'^[a-zA-Z0-9\-_]{10,}$', warehouse_id): + return False, "Warehouse ID format appears invalid" + + return True, None + + +def validate_url(url: str, name: str = "URL") -> Tuple[bool, Optional[str]]: + """Validate URL format.""" + if not url: + return False, f"{name} cannot be empty" + + if not re.match(r'^https?://', url): + return False, f"{name} must start with http:// or https://" + + return True, None + + +def validate_required_files(file_paths: List[str]) -> Tuple[bool, List[str]]: + """Validate that required files exist.""" + import os + from pathlib import Path + + missing_files = [] + # Get templates directory relative to this module + template_dir = Path(__file__).parent.parent / "templates" + + for file_path in file_paths: + # Check in templates directory + template_path = template_dir / file_path + if not template_path.exists(): + missing_files.append(file_path) + + return len(missing_files) == 0, missing_files \ No newline at end of file diff --git a/sql/destroy/cleanup_catalog.sql b/sql/destroy/cleanup_catalog.sql deleted file mode 100644 index 1c874a2..0000000 --- a/sql/destroy/cleanup_catalog.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Destroy dedicated catalog and all its contents -DROP CATALOG IF EXISTS ${PREFIX}_catalog CASCADE; \ No newline at end of file diff --git a/sql/destroy/drop_table.sql b/sql/destroy/drop_table.sql deleted file mode 100644 index 2b10fba..0000000 --- a/sql/destroy/drop_table.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Drop sample customer data table during cleanup -DROP TABLE IF EXISTS ${PREFIX}_customer_data; \ No newline at end of file diff --git a/sql/destroy/remove_column_masks.sql b/sql/destroy/remove_column_masks.sql deleted file mode 100644 index d7d061d..0000000 --- a/sql/destroy/remove_column_masks.sql +++ /dev/null @@ -1,8 +0,0 @@ --- Remove column masks before dropping table during cleanup --- Note: Only first_name should have a mask in current setup, but including all for safety -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN first_name DROP MASK; -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN last_name DROP MASK; -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN email DROP MASK; -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN phone_number DROP MASK; -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN address DROP MASK; -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN date_of_birth DROP MASK; \ No newline at end of file diff --git a/sql/setup/apply_column_masks.sql b/sql/setup/apply_column_masks.sql deleted file mode 100644 index 17f5d71..0000000 --- a/sql/setup/apply_column_masks.sql +++ /dev/null @@ -1,4 +0,0 @@ --- Apply column masks to PII columns using Unity Catalog SQL-only functions --- Pure SQL performance with UC connections - zero Python UDF overhead --- Only auditors get detokenized data, others see raw tokens (optimized for performance) -ALTER TABLE ${PREFIX}_customer_data ALTER COLUMN first_name SET MASK ${PREFIX}_skyflow_mask_detokenize; \ No newline at end of file diff --git a/sql/setup/create_catalog.sql b/sql/setup/create_catalog.sql deleted file mode 100644 index 18c2e62..0000000 --- a/sql/setup/create_catalog.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Create dedicated catalog and schema for this 
instance -CREATE CATALOG IF NOT EXISTS ${PREFIX}_catalog; -CREATE SCHEMA IF NOT EXISTS ${PREFIX}_catalog.default; \ No newline at end of file diff --git a/sql/verify/check_functions_exist.sql b/sql/verify/check_functions_exist.sql deleted file mode 100644 index 1531368..0000000 --- a/sql/verify/check_functions_exist.sql +++ /dev/null @@ -1,4 +0,0 @@ --- Check if functions still exist (used for destroy verification) --- These will error if functions don't exist, which is expected for verification -DESCRIBE FUNCTION ${PREFIX}_skyflow_uc_detokenize; -DESCRIBE FUNCTION ${PREFIX}_skyflow_mask_detokenize; \ No newline at end of file diff --git a/sql/verify/check_table_exists.sql b/sql/verify/check_table_exists.sql deleted file mode 100644 index 0ea004a..0000000 --- a/sql/verify/check_table_exists.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Check if table still exists (used for destroy verification) --- Will error if table doesn't exist, which is expected for verification -DESCRIBE TABLE ${PREFIX}_customer_data; \ No newline at end of file diff --git a/sql/verify/verify_functions.sql b/sql/verify/verify_functions.sql deleted file mode 100644 index 99d6f9f..0000000 --- a/sql/verify/verify_functions.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Verify Unity Catalog detokenization functions exist -DESCRIBE FUNCTION ${PREFIX}_skyflow_uc_detokenize; -DESCRIBE FUNCTION ${PREFIX}_skyflow_mask_detokenize; \ No newline at end of file diff --git a/sql/verify/verify_table.sql b/sql/verify/verify_table.sql deleted file mode 100644 index 59dc990..0000000 --- a/sql/verify/verify_table.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Verify sample table exists before applying column masks -DESCRIBE TABLE ${PREFIX}_customer_data; \ No newline at end of file
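
Note on usage: the helper classes added in this diff (SecretsManager, UnityCatalogManager, SQLExecutor) are intended to be driven by the new `python setup.py create <prefix>` entry point. As a rough sketch only, not the actual entry point in this PR, they could be composed as below. Class and method names are taken from the diff; the env-var plumbing, the `demo` prefix, the warehouse-id derivation, and the host stripping are assumptions for illustration.

```python
# Sketch: wiring the new databricks_ops helpers together (assumptions noted inline).
import os

from databricks.sdk import WorkspaceClient

from databricks_ops.secrets import SecretsManager
from databricks_ops.sql import SQLExecutor
from databricks_ops.unity_catalog import UnityCatalogManager

prefix = "demo"  # assumed prefix -> demo_catalog / demo_customer_data

client = WorkspaceClient(
    host=f"https://{os.environ['DATABRICKS_SERVER_HOSTNAME']}",
    token=os.environ["DATABRICKS_PAT_TOKEN"],
)

# 1. Store Skyflow credentials in the skyflow-secrets scope
SecretsManager(client).setup_skyflow_secrets({
    "pat_token": os.environ["SKYFLOW_PAT_TOKEN"],
    "vault_id": os.environ["SKYFLOW_VAULT_ID"],
    "table": os.environ["SKYFLOW_TABLE"],
})

# 2. Create the dedicated catalog and default schema
uc = UnityCatalogManager(client)
uc.create_catalog(f"{prefix}_catalog")
uc.create_schema(f"{prefix}_catalog")

# 3. Run the templated SQL against a SQL warehouse
#    (warehouse id assumed to be the tail of DATABRICKS_HTTP_PATH)
warehouse_id = os.environ["DATABRICKS_HTTP_PATH"].rsplit("/", 1)[-1]
sql = SQLExecutor(client, warehouse_id=warehouse_id)
subs = {
    "PREFIX": prefix,
    # CREATE CONNECTION expects a bare hostname, so strip the scheme (assumed)
    "SKYFLOW_VAULT_URL": os.environ["SKYFLOW_VAULT_URL"].replace("https://", ""),
    "SKYFLOW_VAULT_ID": os.environ["SKYFLOW_VAULT_ID"],
    "SKYFLOW_TABLE": os.environ["SKYFLOW_TABLE"],
}
for template in (
    "sql/setup/create_uc_connections.sql",      # skyflow_conn HTTP connection
    "sql/setup/create_sample_table.sql",        # sample customer table
    "sql/setup/setup_uc_connections_api.sql",   # detokenize / mask functions
):
    sql.execute_sql_file(template, subs)

# The tokenization-notebook upload/run step from this diff would happen here;
# it is omitted because its manager class lives outside this excerpt.

# 4. Apply role-based column masks once the PII columns hold tokens
sql.execute_sql_file("sql/setup/apply_column_masks.sql", subs)
```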