diff --git a/.env.local.example b/.env.local.example index cb16be3..ad8b38c 100644 --- a/.env.local.example +++ b/.env.local.example @@ -5,7 +5,6 @@ DATABRICKS_PAT_TOKEN=dapi123...your-pat-token DATABRICKS_METASTORE_REGION=us-west-1 # Skyflow Configuration -SKYFLOW_ACCOUNT_ID=your-account-id SKYFLOW_VAULT_URL=https://your-vault.vault.skyflowapis.com SKYFLOW_VAULT_ID=your-vault-id SKYFLOW_PAT_TOKEN=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...your-pat-token diff --git a/README.md b/README.md index a3bdd3d..5a4eca0 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ This solution provides secure data tokenization and detokenization capabilities # Skyflow Configuration SKYFLOW_VAULT_URL=https://your-vault.vault.skyflowapis.com SKYFLOW_PAT_TOKEN=eyJhbGci...your-pat-token - SKYFLOW_ACCOUNT_ID=your-account-id SKYFLOW_VAULT_ID=your-vault-id ``` @@ -300,7 +299,6 @@ WAREHOUSE_ID=abc123... # Skyflow Integration SKYFLOW_VAULT_URL=https://vault.skyflow.com SKYFLOW_VAULT_ID=abc123... -SKYFLOW_ACCOUNT_ID=acc123... SKYFLOW_PAT_TOKEN=sky123... 
SKYFLOW_TABLE=customer_data diff --git a/config.sh b/config.sh index d7939af..14a40f5 100644 --- a/config.sh +++ b/config.sh @@ -24,7 +24,6 @@ fi # Skyflow settings from .env.local (no hardcoded defaults) DEFAULT_SKYFLOW_VAULT_URL="$SKYFLOW_VAULT_URL" DEFAULT_SKYFLOW_VAULT_ID="$SKYFLOW_VAULT_ID" -DEFAULT_SKYFLOW_ACCOUNT_ID="$SKYFLOW_ACCOUNT_ID" DEFAULT_SKYFLOW_PAT_TOKEN="$SKYFLOW_PAT_TOKEN" DEFAULT_SKYFLOW_TABLE="$SKYFLOW_TABLE" @@ -40,7 +39,6 @@ export WAREHOUSE_ID=${WAREHOUSE_ID:-$DEFAULT_WAREHOUSE_ID} export SKYFLOW_VAULT_URL=${SKYFLOW_VAULT_URL:-$DEFAULT_SKYFLOW_VAULT_URL} export SKYFLOW_VAULT_ID=${SKYFLOW_VAULT_ID:-$DEFAULT_SKYFLOW_VAULT_ID} -export SKYFLOW_ACCOUNT_ID=${SKYFLOW_ACCOUNT_ID:-$DEFAULT_SKYFLOW_ACCOUNT_ID} export SKYFLOW_PAT_TOKEN=${SKYFLOW_PAT_TOKEN:-$DEFAULT_SKYFLOW_PAT_TOKEN} export SKYFLOW_TABLE=${SKYFLOW_TABLE:-$DEFAULT_SKYFLOW_TABLE} diff --git a/notebooks/notebook_tokenize_table.ipynb b/notebooks/notebook_tokenize_table.ipynb index 62ba30a..8a1eda7 100644 --- a/notebooks/notebook_tokenize_table.ipynb +++ b/notebooks/notebook_tokenize_table.ipynb @@ -6,7 +6,278 @@ "id": "cell-0", "metadata": {}, "outputs": [], - "source": "# Unity Catalog-aware serverless tokenization notebook \n# Optimized version with chunked processing and MERGE operations for maximum performance\nimport requests\nimport os\nfrom pyspark.sql import SparkSession\nfrom pyspark.dbutils import DBUtils\n\n# Initialize Spark session optimized for serverless compute\nspark = SparkSession.builder \\\n .appName(\"SkyflowTokenization\") \\\n .config(\"spark.databricks.cluster.profile\", \"serverless\") \\\n .config(\"spark.databricks.delta.autoCompact.enabled\", \"true\") \\\n .config(\"spark.sql.adaptive.enabled\", \"true\") \\\n .config(\"spark.sql.adaptive.coalescePartitions.enabled\", \"true\") \\\n .getOrCreate()\n \ndbutils = DBUtils(spark)\n\nprint(f\"✓ Running on Databricks serverless compute\")\nprint(f\"✓ Spark version: {spark.version}\")\n\n# Performance 
configuration\nMAX_MERGE_BATCH_SIZE = 10000 # Maximum records per MERGE operation\nCOLLECT_BATCH_SIZE = 1000 # Maximum records to collect() from Spark at once\n\n# Define widgets to receive input parameters\ndbutils.widgets.text(\"table_name\", \"\")\ndbutils.widgets.text(\"pii_columns\", \"\")\ndbutils.widgets.text(\"batch_size\", \"25\") # Skyflow API batch size\n\n# Read widget values\ntable_name = dbutils.widgets.get(\"table_name\")\npii_columns = dbutils.widgets.get(\"pii_columns\").split(\",\")\nSKYFLOW_BATCH_SIZE = int(dbutils.widgets.get(\"batch_size\"))\n\nif not table_name or not pii_columns:\n raise ValueError(\"Both 'table_name' and 'pii_columns' must be provided.\")\n\nprint(f\"Tokenizing table: {table_name}\")\nprint(f\"PII columns: {', '.join(pii_columns)}\")\nprint(f\"Skyflow API batch size: {SKYFLOW_BATCH_SIZE}\")\nprint(f\"MERGE batch size limit: {MAX_MERGE_BATCH_SIZE:,} records\")\nprint(f\"Collect batch size: {COLLECT_BATCH_SIZE:,} records\")\n\n# Extract catalog and schema from table name if fully qualified\nif '.' 
in table_name:\n parts = table_name.split('.')\n if len(parts) == 3: # catalog.schema.table\n catalog_name = parts[0]\n schema_name = parts[1]\n table_name_only = parts[2]\n \n # Set the catalog and schema context for this session\n print(f\"Setting catalog context to: {catalog_name}\")\n spark.sql(f\"USE CATALOG {catalog_name}\")\n spark.sql(f\"USE SCHEMA {schema_name}\")\n \n # Use the simple table name for queries since context is set\n table_name = table_name_only\n print(f\"✓ Catalog context set, using table name: {table_name}\")\n\n# Get Skyflow credentials from UC secrets (serverless-compatible)\ntry:\n SKYFLOW_VAULT_URL = dbutils.secrets.get(scope=\"skyflow-secrets\", key=\"skyflow_vault_url\")\n SKYFLOW_VAULT_ID = dbutils.secrets.get(scope=\"skyflow-secrets\", key=\"skyflow_vault_id\")\n SKYFLOW_ACCOUNT_ID = dbutils.secrets.get(scope=\"skyflow-secrets\", key=\"skyflow_account_id\")\n SKYFLOW_PAT_TOKEN = dbutils.secrets.get(scope=\"skyflow-secrets\", key=\"skyflow_pat_token\")\n SKYFLOW_TABLE = dbutils.secrets.get(scope=\"skyflow-secrets\", key=\"skyflow_table\")\n \n print(\"✓ Successfully retrieved credentials from UC secrets\")\nexcept Exception as e:\n print(f\"Error retrieving UC secrets: {e}\")\n raise ValueError(\"Could not retrieve Skyflow credentials from UC secrets\")\n\n# Build API URL\nSKYFLOW_API_URL = f\"{SKYFLOW_VAULT_URL}/v1/vaults/{SKYFLOW_VAULT_ID}/{SKYFLOW_TABLE}\"\n\ndef tokenize_column_values(column_name, values):\n \"\"\"\n Tokenize a list of PII values for a specific column via Skyflow API.\n Simplified - no deduplication, direct 1:1 mapping.\n Returns list of tokens in same order as input values.\n \"\"\"\n if not values:\n return []\n \n headers = {\n \"Content-Type\": \"application/json\",\n \"Accept\": \"application/json\", \n \"X-SKYFLOW-ACCOUNT-ID\": SKYFLOW_ACCOUNT_ID,\n \"Authorization\": f\"Bearer {SKYFLOW_PAT_TOKEN}\"\n }\n\n # Create records for each value (no deduplication)\n skyflow_records = [{\n \"fields\": 
{\"pii_values\": str(value)}\n } for value in values if value is not None]\n\n payload = {\n \"records\": skyflow_records,\n \"tokenization\": True\n }\n\n try:\n print(f\" Tokenizing {len(skyflow_records)} values for {column_name}\")\n response = requests.post(SKYFLOW_API_URL, headers=headers, json=payload, timeout=30)\n response.raise_for_status()\n result = response.json()\n \n # Extract tokens in order (1:1 mapping)\n tokens = []\n for i, record in enumerate(result.get(\"records\", [])):\n if \"tokens\" in record and \"pii_values\" in record[\"tokens\"]:\n token = record[\"tokens\"][\"pii_values\"]\n tokens.append(token)\n else:\n print(f\" Value {i+1}: failed to tokenize, keeping original\")\n tokens.append(values[i] if i < len(values) and values[i] is not None else None)\n \n return tokens\n \n except requests.exceptions.RequestException as e:\n print(f\" ❌ ERROR tokenizing {column_name}: {e}\")\n \n # Show detailed API error response for troubleshooting\n if hasattr(e, 'response') and e.response:\n try:\n error_details = e.response.json()\n print(f\" API Error Details: {error_details}\")\n except:\n print(f\" API Error Response: {e.response.text}\")\n print(f\" Status Code: {e.response.status_code}\")\n print(f\" Headers: {dict(e.response.headers)}\")\n \n # Return original values on error\n return [str(val) if val is not None else None for val in values]\n except Exception as e:\n print(f\" ❌ UNEXPECTED ERROR tokenizing {column_name}: {e}\")\n return [str(val) if val is not None else None for val in values]\n\ndef perform_chunked_merge(table_name, column, update_data):\n \"\"\"\n Perform MERGE operations in chunks to avoid memory/timeout issues.\n Returns total number of rows updated.\n \"\"\"\n if not update_data:\n return 0\n \n total_updated = 0\n chunk_size = MAX_MERGE_BATCH_SIZE\n total_chunks = (len(update_data) + chunk_size - 1) // chunk_size\n \n print(f\" Splitting {len(update_data):,} updates into {total_chunks} MERGE operations (max 
{chunk_size:,} per chunk)\")\n \n for chunk_idx in range(0, len(update_data), chunk_size):\n chunk_data = update_data[chunk_idx:chunk_idx + chunk_size]\n chunk_num = (chunk_idx // chunk_size) + 1\n \n try:\n # Create temporary view for this chunk\n temp_df = spark.createDataFrame(chunk_data, [\"customer_id\", f\"new_{column}\"])\n temp_view_name = f\"temp_{column}_chunk_{chunk_num}_{hash(column) % 1000}\"\n temp_df.createOrReplaceTempView(temp_view_name)\n \n # Perform MERGE operation for this chunk\n merge_sql = f\"\"\"\n MERGE INTO `{table_name}` AS target\n USING {temp_view_name} AS source\n ON target.customer_id = source.customer_id\n WHEN MATCHED THEN \n UPDATE SET `{column}` = source.new_{column}\n \"\"\"\n \n spark.sql(merge_sql)\n chunk_updated = len(chunk_data)\n total_updated += chunk_updated\n \n print(f\" Chunk {chunk_num}/{total_chunks}: Updated {chunk_updated:,} rows\")\n \n # Clean up temp view\n spark.catalog.dropTempView(temp_view_name)\n \n except Exception as e:\n print(f\" Error in chunk {chunk_num}: {e}\")\n print(f\" Falling back to row-by-row for this chunk...\")\n \n # Fallback to row-by-row for this chunk only\n chunk_fallback_count = 0\n for customer_id, token in chunk_data:\n try:\n spark.sql(f\"\"\"\n UPDATE `{table_name}` \n SET `{column}` = '{token}' \n WHERE customer_id = '{customer_id}'\n \"\"\")\n chunk_fallback_count += 1\n except Exception as row_e:\n print(f\" Error updating customer_id {customer_id}: {row_e}\")\n \n total_updated += chunk_fallback_count\n print(f\" Chunk {chunk_num} fallback: Updated {chunk_fallback_count} rows\")\n \n return total_updated\n\n# Process each column individually (streaming approach)\nprint(\"Starting column-by-column tokenization with streaming chunked processing...\")\n\nfor column in pii_columns:\n print(f\"\\nProcessing column: {column}\")\n \n # Get total count first for progress tracking\n total_count = spark.sql(f\"\"\"\n SELECT COUNT(*) as count \n FROM `{table_name}` \n WHERE `{column}` IS 
NOT NULL\n \"\"\").collect()[0]['count']\n \n if total_count == 0:\n print(f\" No data found in column {column}\")\n continue\n \n print(f\" Found {total_count:,} total values to tokenize\")\n \n # Process in streaming chunks to avoid memory issues\n all_update_data = [] # Collect all updates before final MERGE\n processed_count = 0\n \n for offset in range(0, total_count, COLLECT_BATCH_SIZE):\n chunk_size = min(COLLECT_BATCH_SIZE, total_count - offset)\n print(f\" Processing chunk {offset//COLLECT_BATCH_SIZE + 1} ({chunk_size:,} records, offset {offset:,})...\")\n \n # Get chunk of data from Spark\n chunk_df = spark.sql(f\"\"\"\n SELECT customer_id, `{column}` \n FROM `{table_name}` \n WHERE `{column}` IS NOT NULL \n ORDER BY customer_id\n LIMIT {chunk_size} OFFSET {offset}\n \"\"\")\n \n chunk_rows = chunk_df.collect()\n if not chunk_rows:\n continue\n \n # Extract customer IDs and values for this chunk\n chunk_customer_ids = [row['customer_id'] for row in chunk_rows]\n chunk_column_values = [row[column] for row in chunk_rows]\n \n # Tokenize this chunk's values in Skyflow API batches\n chunk_tokens = []\n if len(chunk_column_values) <= SKYFLOW_BATCH_SIZE: # Single API batch\n chunk_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}\", chunk_column_values)\n else: # Multiple API batches within this chunk\n for i in range(0, len(chunk_column_values), SKYFLOW_BATCH_SIZE):\n api_batch_values = chunk_column_values[i:i + SKYFLOW_BATCH_SIZE]\n api_batch_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}_api_{i//SKYFLOW_BATCH_SIZE + 1}\", api_batch_values)\n chunk_tokens.extend(api_batch_tokens)\n \n if len(chunk_tokens) != len(chunk_customer_ids):\n print(f\" Warning: Token count ({len(chunk_tokens):,}) doesn't match chunk row count ({len(chunk_customer_ids):,})\")\n continue\n \n # Collect update data for rows that changed in this chunk\n chunk_original_map = {chunk_customer_ids[i]: chunk_column_values[i] 
for i in range(len(chunk_customer_ids))}\n \n for i, (customer_id, token) in enumerate(zip(chunk_customer_ids, chunk_tokens)):\n if token and str(token) != str(chunk_original_map[customer_id]):\n all_update_data.append((customer_id, token))\n \n processed_count += len(chunk_rows)\n print(f\" Processed {processed_count:,}/{total_count:,} records ({(processed_count/total_count)*100:.1f}%)\")\n \n # Perform final chunked MERGE operations for all collected updates\n if all_update_data:\n print(f\" Performing final chunked MERGE of {len(all_update_data):,} changed rows...\")\n total_updated = perform_chunked_merge(table_name, column, all_update_data)\n print(f\" ✓ Successfully updated {total_updated:,} rows in column {column}\")\n else:\n print(f\" No updates needed - all tokens match original values\")\n\nprint(\"\\nOptimized streaming tokenization completed!\")\n\n# Verify results\nprint(\"\\nFinal verification:\")\nfor column in pii_columns:\n sample_df = spark.sql(f\"\"\"\n SELECT `{column}`, COUNT(*) as count \n FROM `{table_name}` \n GROUP BY `{column}` \n LIMIT 3\n \"\"\")\n print(f\"\\nSample values in {column}:\")\n sample_df.show(truncate=False)\n\ntotal_rows = spark.sql(f\"SELECT COUNT(*) as count FROM `{table_name}`\").collect()[0][\"count\"]\nprint(f\"\\nTable size: {total_rows:,} total rows\")\n\ndbutils.notebook.exit(f\"Optimized streaming tokenization completed for {len(pii_columns)} columns\")" + "source": [ + "# Unity Catalog-aware serverless tokenization notebook \n", + "# Uses dbutils.secrets.get() + UC HTTP connections for serverless compatibility\n", + "import json\n", + "import os\n", + "from pyspark.sql import SparkSession\n", + "from pyspark.dbutils import DBUtils\n", + "\n", + "# Initialize Spark session optimized for serverless compute\n", + "spark = SparkSession.builder \\\n", + " .appName(\"SkyflowTokenization\") \\\n", + " .config(\"spark.databricks.cluster.profile\", \"serverless\") \\\n", + " 
.config(\"spark.databricks.delta.autoCompact.enabled\", \"true\") \\\n", + " .config(\"spark.sql.adaptive.enabled\", \"true\") \\\n", + " .config(\"spark.sql.adaptive.coalescePartitions.enabled\", \"true\") \\\n", + " .getOrCreate()\n", + " \n", + "dbutils = DBUtils(spark)\n", + "\n", + "print(f\"✓ Running on Databricks serverless compute\")\n", + "print(f\"✓ Spark version: {spark.version}\")\n", + "\n", + "# Performance configuration\n", + "MAX_MERGE_BATCH_SIZE = 10000 # Maximum records per MERGE operation\n", + "COLLECT_BATCH_SIZE = 1000 # Maximum records to collect() from Spark at once\n", + "\n", + "# Define widgets to receive input parameters\n", + "dbutils.widgets.text(\"table_name\", \"\")\n", + "dbutils.widgets.text(\"pii_columns\", \"\")\n", + "dbutils.widgets.text(\"batch_size\", \"25\") # Skyflow API batch size (non-empty default: value is parsed with int() below)\n", + "\n", + "# Read widget values\n", + "table_name = dbutils.widgets.get(\"table_name\")\n", + "pii_columns = dbutils.widgets.get(\"pii_columns\").split(\",\")\n", + "SKYFLOW_BATCH_SIZE = int(dbutils.widgets.get(\"batch_size\"))\n", + "\n", + "if not table_name or not pii_columns:\n", + " raise ValueError(\"Both 'table_name' and 'pii_columns' must be provided.\")\n", + "\n", + "print(f\"Tokenizing table: {table_name}\")\n", + "print(f\"PII columns: {', '.join(pii_columns)}\")\n", + "print(f\"Skyflow API batch size: {SKYFLOW_BATCH_SIZE}\")\n", + "print(f\"MERGE batch size limit: {MAX_MERGE_BATCH_SIZE:,} records\")\n", + "print(f\"Collect batch size: {COLLECT_BATCH_SIZE:,} records\")\n", + "\n", + "# Extract catalog and schema from table name if fully qualified\n", + "if '.' 
in table_name:\n", + " parts = table_name.split('.')\n", + " if len(parts) == 3: # catalog.schema.table\n", + " catalog_name = parts[0]\n", + " schema_name = parts[1]\n", + " table_name_only = parts[2]\n", + " \n", + " # Set the catalog and schema context for this session\n", + " print(f\"Setting catalog context to: {catalog_name}\")\n", + " spark.sql(f\"USE CATALOG {catalog_name}\")\n", + " spark.sql(f\"USE SCHEMA {schema_name}\")\n", + " \n", + " # Use the simple table name for queries since context is set\n", + " table_name = table_name_only\n", + " print(f\"✓ Catalog context set, using table name: {table_name}\")\n", + "\n", + "print(\"✓ Using dbutils.secrets.get() + UC HTTP connections for serverless compatibility\")\n", + "\n", + "def tokenize_column_values(column_name, values):\n", + " \"\"\"\n", + " Tokenize a list of PII values using Unity Catalog HTTP connection.\n", + " Uses dbutils.secrets.get() and http_request() for serverless compatibility.\n", + " Returns list of tokens in same order as input values.\n", + " \"\"\"\n", + " if not values:\n", + " return []\n", + " \n", + " # Get secrets using dbutils (works in serverless)\n", + " table_column = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table_column\")\n", + " vault_id = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_vault_id\")\n", + " skyflow_table = dbutils.secrets.get(\"skyflow-secrets\", \"skyflow_table\")\n", + " \n", + " # Create records for each value\n", + " skyflow_records = [{\n", + " \"fields\": {table_column: str(value)}\n", + " } for value in values if value is not None]\n", + "\n", + " # Create Skyflow tokenization payload\n", + " payload = {\n", + " \"records\": skyflow_records,\n", + " \"tokenization\": True\n", + " }\n", + "\n", + " print(f\" Tokenizing {len(skyflow_records)} values for {column_name}\")\n", + " \n", + " # Use Unity Catalog HTTP connection via SQL http_request function\n", + " json_payload = json.dumps(payload).replace(\"'\", \"''\")\n", + " tokenize_path 
= f\"{vault_id}/{skyflow_table}\"\n", + " \n", + " # Execute tokenization via consolidated UC connection\n", + " result_df = spark.sql(f\"\"\"\n", + " SELECT http_request(\n", + " conn => 'skyflow_conn',\n", + " method => 'POST',\n", + " path => '{tokenize_path}',\n", + " headers => map(\n", + " 'Content-Type', 'application/json',\n", + " 'Accept', 'application/json'\n", + " ),\n", + " json => '{json_payload}'\n", + " ) as full_response\n", + " \"\"\")\n", + " \n", + " # Parse response\n", + " full_response = result_df.collect()[0]['full_response']\n", + " result = json.loads(full_response.text)\n", + " \n", + " # Fail fast if API response indicates error\n", + " if \"error\" in result:\n", + " raise RuntimeError(f\"Skyflow API error: {result['error']}\")\n", + " \n", + " if \"records\" not in result:\n", + " raise RuntimeError(f\"Invalid Skyflow API response - missing 'records': {result}\")\n", + " \n", + " # Extract tokens in order\n", + " tokens = []\n", + " for i, record in enumerate(result.get(\"records\", [])):\n", + " if \"tokens\" in record and table_column in record[\"tokens\"]:\n", + " token = record[\"tokens\"][table_column]\n", + " tokens.append(token)\n", + " else:\n", + " raise RuntimeError(f\"Tokenization failed for value {i+1} in {column_name}. 
Record: {record}\")\n", + " \n", + " successful_tokens = len([t for i, t in enumerate(tokens) if t and str(t) != str(values[i])])\n", + " print(f\" Successfully tokenized {successful_tokens}/{len(values)} values\")\n", + " \n", + " return tokens\n", + "\n", + "def perform_chunked_merge(table_name, column, update_data):\n", + " \"\"\"\n", + " Perform MERGE operations in chunks to avoid memory/timeout issues.\n", + " Returns total number of rows updated.\n", + " \"\"\"\n", + " if not update_data:\n", + " return 0\n", + " \n", + " total_updated = 0\n", + " chunk_size = MAX_MERGE_BATCH_SIZE\n", + " total_chunks = (len(update_data) + chunk_size - 1) // chunk_size\n", + " \n", + " print(f\" Splitting {len(update_data):,} updates into {total_chunks} MERGE operations (max {chunk_size:,} per chunk)\")\n", + " \n", + " for chunk_idx in range(0, len(update_data), chunk_size):\n", + " chunk_data = update_data[chunk_idx:chunk_idx + chunk_size]\n", + " chunk_num = (chunk_idx // chunk_size) + 1\n", + " \n", + " # Create temporary view for this chunk\n", + " temp_df = spark.createDataFrame(chunk_data, [\"customer_id\", f\"new_{column}\"])\n", + " temp_view_name = f\"temp_{column}_chunk_{chunk_num}_{hash(column) % 1000}\"\n", + " temp_df.createOrReplaceTempView(temp_view_name)\n", + " \n", + " # Perform MERGE operation for this chunk\n", + " merge_sql = f\"\"\"\n", + " MERGE INTO `{table_name}` AS target\n", + " USING {temp_view_name} AS source\n", + " ON target.customer_id = source.customer_id\n", + " WHEN MATCHED THEN \n", + " UPDATE SET `{column}` = source.new_{column}\n", + " \"\"\"\n", + " \n", + " spark.sql(merge_sql)\n", + " chunk_updated = len(chunk_data)\n", + " total_updated += chunk_updated\n", + " \n", + " print(f\" Chunk {chunk_num}/{total_chunks}: Updated {chunk_updated:,} rows\")\n", + " \n", + " # Clean up temp view\n", + " spark.catalog.dropTempView(temp_view_name)\n", + " \n", + " return total_updated\n", + "\n", + "# Process each column individually (streaming 
approach)\n", + "print(\"Starting column-by-column tokenization with streaming chunked processing...\")\n", + "\n", + "for column in pii_columns:\n", + " print(f\"\\nProcessing column: {column}\")\n", + " \n", + " # Get total count first for progress tracking\n", + " total_count = spark.sql(f\"\"\"\n", + " SELECT COUNT(*) as count \n", + " FROM `{table_name}` \n", + " WHERE `{column}` IS NOT NULL\n", + " \"\"\").collect()[0]['count']\n", + " \n", + " if total_count == 0:\n", + " print(f\" No data found in column {column}\")\n", + " continue\n", + " \n", + " print(f\" Found {total_count:,} total values to tokenize\")\n", + " \n", + " # Process in streaming chunks to avoid memory issues\n", + " all_update_data = [] # Collect all updates before final MERGE\n", + " processed_count = 0\n", + " \n", + " for offset in range(0, total_count, COLLECT_BATCH_SIZE):\n", + " chunk_size = min(COLLECT_BATCH_SIZE, total_count - offset)\n", + " print(f\" Processing chunk {offset//COLLECT_BATCH_SIZE + 1} ({chunk_size:,} records, offset {offset:,})...\")\n", + " \n", + " # Get chunk of data from Spark\n", + " chunk_df = spark.sql(f\"\"\"\n", + " SELECT customer_id, `{column}` \n", + " FROM `{table_name}` \n", + " WHERE `{column}` IS NOT NULL \n", + " ORDER BY customer_id\n", + " LIMIT {chunk_size} OFFSET {offset}\n", + " \"\"\")\n", + " \n", + " chunk_rows = chunk_df.collect()\n", + " if not chunk_rows:\n", + " continue\n", + " \n", + " # Extract customer IDs and values for this chunk\n", + " chunk_customer_ids = [row['customer_id'] for row in chunk_rows]\n", + " chunk_column_values = [row[column] for row in chunk_rows]\n", + " \n", + " # Tokenize this chunk's values in Skyflow API batches\n", + " chunk_tokens = []\n", + " if len(chunk_column_values) <= SKYFLOW_BATCH_SIZE: # Single API batch\n", + " chunk_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}\", chunk_column_values)\n", + " else: # Multiple API batches within this chunk\n", + " for i in 
range(0, len(chunk_column_values), SKYFLOW_BATCH_SIZE):\n", + " api_batch_values = chunk_column_values[i:i + SKYFLOW_BATCH_SIZE]\n", + " api_batch_tokens = tokenize_column_values(f\"{column}_chunk_{offset//COLLECT_BATCH_SIZE + 1}_api_{i//SKYFLOW_BATCH_SIZE + 1}\", api_batch_values)\n", + " chunk_tokens.extend(api_batch_tokens)\n", + " \n", + " # Verify token count matches input count (fail fast)\n", + " if len(chunk_tokens) != len(chunk_customer_ids):\n", + " raise RuntimeError(f\"Token count mismatch: got {len(chunk_tokens)} tokens for {len(chunk_customer_ids)} input values\")\n", + " \n", + " # Collect update data for rows that changed in this chunk\n", + " chunk_original_map = {chunk_customer_ids[i]: chunk_column_values[i] for i in range(len(chunk_customer_ids))}\n", + " \n", + " for i, (customer_id, token) in enumerate(zip(chunk_customer_ids, chunk_tokens)):\n", + " if token and str(token) != str(chunk_original_map[customer_id]):\n", + " all_update_data.append((customer_id, token))\n", + " \n", + " processed_count += len(chunk_rows)\n", + " print(f\" Processed {processed_count:,}/{total_count:,} records ({(processed_count/total_count)*100:.1f}%)\")\n", + " \n", + " # Perform final chunked MERGE operations for all collected updates\n", + " if all_update_data:\n", + " print(f\" Performing final chunked MERGE of {len(all_update_data):,} changed rows...\")\n", + " total_updated = perform_chunked_merge(table_name, column, all_update_data)\n", + " print(f\" ✓ Successfully updated {total_updated:,} rows in column {column}\")\n", + " else:\n", + " print(f\" No updates needed - all tokens match original values\")\n", + "\n", + "print(\"\\nOptimized streaming tokenization completed!\")\n", + "\n", + "# Verify results\n", + "print(\"\\nFinal verification:\")\n", + "for column in pii_columns:\n", + " sample_df = spark.sql(f\"\"\"\n", + " SELECT `{column}`, COUNT(*) as count \n", + " FROM `{table_name}` \n", + " GROUP BY `{column}` \n", + " LIMIT 3\n", + " \"\"\")\n", + " 
print(f\"\\nSample values in {column}:\")\n", + " sample_df.show(truncate=False)\n", + "\n", + "total_rows = spark.sql(f\"SELECT COUNT(*) as count FROM `{table_name}`\").collect()[0][\"count\"]\n", + "print(f\"\\nTable size: {total_rows:,} total rows\")\n", + "\n", + "print(f\"Optimized streaming tokenization completed for {len(pii_columns)} columns\")" + ] } ], "metadata": { @@ -16,4 +287,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/setup.sh b/setup.sh index ae69012..d02fcbb 100755 --- a/setup.sh +++ b/setup.sh @@ -54,11 +54,6 @@ load_config() { echo "✓ SKYFLOW_VAULT_ID loaded from .env.local" fi - export SKYFLOW_ACCOUNT_ID=${DEFAULT_SKYFLOW_ACCOUNT_ID} - if [[ -n "$SKYFLOW_ACCOUNT_ID" ]]; then - echo "✓ SKYFLOW_ACCOUNT_ID loaded from .env.local" - fi - export SKYFLOW_PAT_TOKEN=${DEFAULT_SKYFLOW_PAT_TOKEN} if [[ -n "$SKYFLOW_PAT_TOKEN" ]]; then echo "✓ SKYFLOW_PAT_TOKEN loaded from .env.local" @@ -92,46 +87,25 @@ setup_uc_connections() { local sql_content=$(cat "sql/setup/create_uc_connections.sql") local processed_content=$(substitute_variables "$sql_content") - # Split into tokenization and detokenization connection statements - local tokenize_sql=$(echo "$processed_content" | sed -n '/CREATE CONNECTION.*skyflow_tokenize_conn/,/);/p') - local detokenize_sql=$(echo "$processed_content" | sed -n '/CREATE CONNECTION.*skyflow_detokenize_conn/,/);$/p') - - # Execute tokenization connection creation with detailed logging (direct API call to avoid catalog context) - echo "Executing tokenization connection SQL without catalog context..." 
- local tokenize_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/sql/statements" \ - -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ - -H "Content-Type: application/json" \ - -d "{\"statement\":$(echo "$tokenize_sql" | python3 -c 'import json,sys; print(json.dumps(sys.stdin.read()))'),\"warehouse_id\":\"${WAREHOUSE_ID}\"}") - - if echo "$tokenize_response" | grep -q '"state":"SUCCEEDED"'; then - echo "✓ Created UC tokenization connection: skyflow_tokenize_conn" - local tokenize_success=true - else - echo "❌ ERROR: Tokenization connection creation failed" - echo "SQL statement was: $tokenize_sql" - echo "Response: $tokenize_response" - local tokenize_success=false - fi - - # Execute detokenization connection creation with detailed logging (direct API call to avoid catalog context) - echo "Executing detokenization connection SQL without catalog context..." - local detokenize_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/sql/statements" \ + # Execute consolidated connection creation with detailed logging (direct API call to avoid catalog context) + echo "Executing consolidated Skyflow connection SQL without catalog context..." 
+ local connection_response=$(curl -s -X POST "${DATABRICKS_HOST}/api/2.0/sql/statements" \ -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ -H "Content-Type: application/json" \ - -d "{\"statement\":$(echo "$detokenize_sql" | python3 -c 'import json,sys; print(json.dumps(sys.stdin.read()))'),\"warehouse_id\":\"${WAREHOUSE_ID}\"}") + -d "{\"statement\":$(echo "$processed_content" | python3 -c 'import json,sys; print(json.dumps(sys.stdin.read()))'),\"warehouse_id\":\"${WAREHOUSE_ID}\"}") - if echo "$detokenize_response" | grep -q '"state":"SUCCEEDED"'; then - echo "✓ Created UC detokenization connection: skyflow_detokenize_conn" - local detokenize_success=true + if echo "$connection_response" | grep -q '"state":"SUCCEEDED"'; then + echo "✓ Created UC consolidated connection: skyflow_conn" + local connection_success=true else - echo "❌ ERROR: Detokenization connection creation failed" - echo "SQL statement was: $detokenize_sql" - echo "Response: $detokenize_response" - local detokenize_success=false + echo "❌ ERROR: Consolidated connection creation failed" + echo "SQL statement was: $processed_content" + echo "Response: $connection_response" + local connection_success=false fi - # Verify connections actually exist after creation - echo "Verifying UC connections were actually created..." + # Verify connection actually exists after creation + echo "Verifying UC connection was actually created..." 
local actual_connections=$(curl -s -H "Authorization: Bearer ${DATABRICKS_TOKEN}" \ "${DATABRICKS_HOST}/api/2.1/unity-catalog/connections" | \ python3 -c " @@ -144,27 +118,20 @@ except: print('') ") - if echo "$actual_connections" | grep -q "skyflow_tokenize_conn"; then - echo "✓ Verified skyflow_tokenize_conn exists" - else - echo "❌ skyflow_tokenize_conn NOT FOUND after creation" - tokenize_success=false - fi - - if echo "$actual_connections" | grep -q "skyflow_detokenize_conn"; then - echo "✓ Verified skyflow_detokenize_conn exists" + if echo "$actual_connections" | grep -q "skyflow_conn"; then + echo "✓ Verified skyflow_conn exists" else - echo "❌ skyflow_detokenize_conn NOT FOUND after creation" - detokenize_success=false + echo "❌ skyflow_conn NOT FOUND after creation" + connection_success=false fi - # Return success only if both connections succeeded - if [ "$tokenize_success" = true ] && [ "$detokenize_success" = true ]; then - echo "✓ Both UC connections created successfully via SQL" + # Return success only if connection succeeded + if [ "$connection_success" = true ]; then + echo "✓ UC consolidated connection created successfully via SQL" return 0 else - echo "❌ Failed to create required UC connections" - echo "Both connections must be created successfully for setup to proceed" + echo "❌ Failed to create required UC connection" + echo "Connection must be created successfully for setup to proceed" exit 1 fi } @@ -192,13 +159,13 @@ setup_uc_secrets() { echo "✓ Created secrets scope successfully" fi - # Create individual secrets + # Create individual secrets (only ones actually needed for tokenization/detokenization) local secrets=( "skyflow_pat_token:${SKYFLOW_PAT_TOKEN}" - "skyflow_account_id:${SKYFLOW_ACCOUNT_ID}" "skyflow_vault_url:${SKYFLOW_VAULT_URL}" "skyflow_vault_id:${SKYFLOW_VAULT_ID}" "skyflow_table:${SKYFLOW_TABLE}" + "skyflow_table_column:${SKYFLOW_TABLE_COLUMN}" ) for secret_pair in "${secrets[@]}"; do @@ -410,7 +377,7 @@ except Exception as e: 
# Wait for run to complete echo "Waiting for tokenization to complete..." - local max_wait=300 # 5 minutes + local max_wait=900 # 15 minutes local wait_time=0 while [[ $wait_time -lt $max_wait ]]; do @@ -748,8 +715,8 @@ Resources created: 3. Tokenization notebook: - /Workspace/Shared/${PREFIX}_tokenize_table (serverless-optimized) 4. Unity Catalog Infrastructure: - - SQL-created HTTP connections: skyflow_tokenize_conn, skyflow_detokenize_conn - - UC-backed secrets scope: skyflow-secrets (contains PAT token, account ID, vault details) + - SQL-created HTTP connection: skyflow_conn (consolidated for tokenization and detokenization) + - UC-backed secrets scope: skyflow-secrets (contains PAT token, vault details) - Bearer token authentication with proper secret() references 5. Pure SQL Functions: - ${PREFIX}_catalog.default.${PREFIX}_skyflow_uc_detokenize (direct Skyflow API via UC connections) @@ -818,7 +785,7 @@ import json, sys try: data = json.load(sys.stdin) names = [conn['name'] for conn in data.get('connections', []) - if conn['name'] in ['skyflow_tokenize_conn', 'skyflow_detokenize_conn']] + if conn['name'] in ['skyflow_conn', 'skyflow_tokenize_conn', 'skyflow_detokenize_conn']] print('\n'.join(names)) except Exception as e: print(f'Error extracting names: {e}') @@ -1009,9 +976,11 @@ import sys, json try: data = json.load(sys.stdin) connections = [c['name'] for c in data.get('connections', [])] + skyflow_conn_exists = 'skyflow_conn' in connections + # Also check for old connections in case of partial migration tokenize_exists = 'skyflow_tokenize_conn' in connections detokenize_exists = 'skyflow_detokenize_conn' in connections - print('true' if (tokenize_exists or detokenize_exists) else 'false') + print('true' if (skyflow_conn_exists or tokenize_exists or detokenize_exists) else 'false') except: print('false') ") diff --git a/sql/setup/create_uc_connections.sql b/sql/setup/create_uc_connections.sql index 700e4fd..6ea591b 100644 --- 
a/sql/setup/create_uc_connections.sql +++ b/sql/setup/create_uc_connections.sql @@ -1,20 +1,11 @@ -- Unity Catalog HTTP connections for Skyflow API integration -- These must be created without catalog context (global metastore resources) --- Tokenization connection -CREATE CONNECTION IF NOT EXISTS skyflow_tokenize_conn TYPE HTTP +-- Single consolidated Skyflow connection for both tokenization and detokenization +CREATE CONNECTION IF NOT EXISTS skyflow_conn TYPE HTTP OPTIONS ( host '${SKYFLOW_VAULT_URL}', port 443, - base_path '/v1/vaults/${SKYFLOW_VAULT_ID}/${SKYFLOW_TABLE}', - bearer_token secret('skyflow-secrets', 'skyflow_pat_token') -); - --- Detokenization connection -CREATE CONNECTION IF NOT EXISTS skyflow_detokenize_conn TYPE HTTP -OPTIONS ( - host '${SKYFLOW_VAULT_URL}', - port 443, - base_path '/v1/vaults/${SKYFLOW_VAULT_ID}', + base_path '/v1/vaults', bearer_token secret('skyflow-secrets', 'skyflow_pat_token') ); \ No newline at end of file diff --git a/sql/setup/setup_uc_connections_api.sql b/sql/setup/setup_uc_connections_api.sql index 39309e8..f24e161 100644 --- a/sql/setup/setup_uc_connections_api.sql +++ b/sql/setup/setup_uc_connections_api.sql @@ -43,13 +43,12 @@ RETURN get_json_object( get_json_object( http_request( - conn => 'skyflow_detokenize_conn', + conn => 'skyflow_conn', method => 'POST', - path => '/detokenize', + path => '${SKYFLOW_VAULT_ID}/detokenize', headers => map( 'Content-Type', 'application/json', - 'Accept', 'application/json', - 'X-SKYFLOW-ACCOUNT-ID', '${SKYFLOW_ACCOUNT_ID}' + 'Accept', 'application/json' ), json => concat( '{"detokenizationParameters":[{"token":"',