From 09907a57f8fde101d6041c843adaca21aa827bf4 Mon Sep 17 00:00:00 2001 From: Jilks Smith Date: Thu, 26 Mar 2026 12:00:28 +0300 Subject: [PATCH 1/3] Add enterprise analytics tutorial --- enterprise-analytics/README.md | 374 ++++++++++++++++++ enterprise-analytics/predict.py | 173 ++++++++ .../tools/prepare_datasets.py | 122 ++++++ enterprise-analytics/train.py | 310 +++++++++++++++ 4 files changed, 979 insertions(+) create mode 100644 enterprise-analytics/README.md create mode 100644 enterprise-analytics/predict.py create mode 100644 enterprise-analytics/tools/prepare_datasets.py create mode 100644 enterprise-analytics/train.py diff --git a/enterprise-analytics/README.md b/enterprise-analytics/README.md new file mode 100644 index 0000000..64e88b1 --- /dev/null +++ b/enterprise-analytics/README.md @@ -0,0 +1,374 @@ +# Enterprise Supply Chain Demand Forecasting + +Confidential multi-party analytics for enterprise supply chain optimization. Three competing retail companies collaboratively train a demand forecasting model inside a Trusted Execution Environment (TEE)—each contributes proprietary transaction data, but **no company ever sees another's raw data**. 
+ +This example demonstrates: + +- **Secure Computation (aTLS)** — Attested TLS verifies the TEE hardware and software stack before any data is uploaded +- **Multi-Party Computation** — Three independent data providers each upload proprietary datasets into the same encrypted enclave +- **Real-World Data** — Uses the [UCI Online Retail II](https://www.kaggle.com/datasets/mashlyn/online-retail-ii-uci) dataset (real European e-commerce transactions) split across simulated companies +- **Enterprise Value** — Benchmark proves the consortium model outperforms any single-company model + +## Table of Contents + +- [Scenario](#scenario) +- [Dataset](#dataset) +- [Architecture](#architecture) +- [Setup Virtual Environment](#setup-virtual-environment) +- [Install](#install) +- [Train Model (Local)](#train-model-local) +- [Test Model (Local)](#test-model-local) +- [Testing with Prism (Multi-Party)](#testing-with-prism-multi-party) +- [Notes](#notes) + +## Scenario + +Three retail companies—**Company 1**, **Company 2**, and **Company 3**—compete in the same market. Each holds proprietary customer transaction data that is a trade secret. No company would willingly hand over raw sales data to a competitor or a third party. + +However, they all recognize that a **consortium demand forecast** trained on the combined market data would be far more accurate than any model they could train alone. Prism AI makes this possible: + +1. A **neutral Algorithm Provider** supplies the training code (`train.py`) +2. Each company acts as a **Data Provider**, uploading encrypted datasets into the TEE +3. The TEE runs the algorithm over all three datasets simultaneously +4. Only the **aggregated results** (trained model + benchmark report) exit the enclave +5. No company ever sees another's raw transactions + +### What Gets Produced + +| Output File | Description | +|---|---| +| `demand_model.ubj` | Trained XGBoost demand forecasting model | +| `benchmark_report.csv` | Consortium accuracy vs. 
individual company models | +| `feature_importance.csv` | Top predictive features ranked by gain | +| `monthly_forecast.csv` | 3-month forward demand prediction | + +## Dataset + +**UCI Online Retail II** — Real transactions from a UK-based online retailer (2009–2011). + +- ~1 million transactions across 4,000+ customers and 4,000+ products +- Covers 43 countries +- Features: Invoice, StockCode, Description, Quantity, InvoiceDate, Price, Customer ID, Country + +The `prepare_datasets.py` tool splits this into 3 company datasets by customer ID, simulating the real-world scenario where each retailer owns a disjoint slice of the market. + +**Source:** [Kaggle — Online Retail II UCI](https://www.kaggle.com/datasets/mashlyn/online-retail-ii-uci) + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Trusted Execution Environment (TEE) │ +│ AMD SEV-SNP / Intel TDX Hardware │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ In-Enclave Agent │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Company 1 │ │ Company 2 │ │ Company 3 │ │ │ +│ │ │ Dataset │ │ Dataset │ │ Dataset │ │ │ +│ │ │ (encrypted) │ │ (encrypted) │ │ (encrypted) │ │ │ +│ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ │ +│ │ │ │ │ │ │ +│ │ └────────────────┼────────────────┘ │ │ +│ │ ▼ │ │ +│ │ ┌──────────────────┐ │ │ +│ │ │ train.py │ │ │ +│ │ │ (Algorithm) │ │ │ +│ │ └────────┬─────────┘ │ │ +│ │ ▼ │ │ +│ │ ┌────────────────────────┐ │ │ +│ │ │ Results: │ │ │ +│ │ │ • demand_model.ubj │ │ │ +│ │ │ • benchmark_report │ │ │ +│ │ │ • feature_importance │ │ │ +│ │ │ • monthly_forecast │ │ │ +│ │ └────────────────────────┘ │ │ +│ └───────────────────────────────────────────────────────────┘ │ +│ │ +│ Memory encrypted by hardware • Host/cloud has zero access │ +└─────────────────────────────────────────────────────────────────┘ + ▲ ▲ ▲ │ + aTLS │ aTLS │ aTLS │ │ aTLS + │ │ │ ▼ + ┌───────────┐ ┌───────────┐ 
┌───────────┐ ┌───────────┐ + │ Company 1 │ │ Company 2 │ │ Company 3 │ │ Result │ + │ (Data │ │ (Data │ │ (Data │ │ Consumer │ + │ Provider)│ │ Provider)│ │ Provider)│ │ │ + └───────────┘ └───────────┘ └───────────┘ └───────────┘ +``` + +**Key security guarantees:** + +- **aTLS (Attested TLS):** Each participant verifies the TEE's hardware attestation before uploading. The cryptographic quote proves the enclave is genuine AMD SEV-SNP/Intel TDX hardware running the exact agreed-upon algorithm. +- **Memory encryption:** All data inside the TEE is encrypted by the CPU. The cloud provider, hypervisor, and host OS have zero access. +- **No raw data exits:** Only aggregated model weights and statistical reports leave the enclave. Individual transaction records are destroyed. + +## Setup Virtual Environment + +```bash +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Install + +Fetch the data from Kaggle — [Online Retail II UCI](https://www.kaggle.com/datasets/mashlyn/online-retail-ii-uci) dataset: + +```bash +kaggle datasets download -d mashlyn/online-retail-ii-uci +``` + +To run the above command you need [kaggle cli](https://github.com/Kaggle/kaggle-api) installed and API credentials set up. Follow [this documentation](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md#kaggle-api). + +You will get `online-retail-ii-uci.zip` in the folder. + +Prepare the 3 company datasets: + +```bash +python tools/prepare_datasets.py online-retail-ii-uci.zip -d datasets +``` + +Expected output: + +``` +Loaded 1067371 rows from online_retail_II.xlsx +After cleaning: 824364 rows, 4384 customers +Company 1: 273421 transactions, 1461 customers, 38 countries +Company 2: 276583 transactions, 1462 customers, 40 countries +Company 3: 274360 transactions, 1461 customers, 39 countries + +Dataset preparation complete. 
3 company datasets saved to 'datasets/' +``` + +## Train Model (Local) + +To train the consortium model locally: + +```bash +python train.py +``` + +The script loads all company CSVs from `datasets/`, builds time-series features, trains a consortium XGBoost model on the combined data, then benchmarks it against individual company models. Results are saved to `results/`. + +## Test Model (Local) + +Analyze the results and generate visualizations: + +```bash +python predict.py +``` + +Output includes benchmark comparisons, feature importance charts, and demand forecast summaries. + +## Testing with Prism (Multi-Party) + +Prism provides a web-based interface for managing multi-party computations with full role-based access control. This is the recommended approach for enterprise deployments. + +### Prerequisites + +1. **Clone and start Prism:** + + ```bash + git clone https://github.com/ultravioletrs/prism.git + cd prism + make run + ``` + +2. **Prepare datasets** (follow the same steps as above) + +3. **Build Cocos artifacts and generate keys:** + + ```bash + cd cocos + make all + ./build/cocos-cli keys -k="rsa" + ``` + +### Multi-Party Setup in Prism + +This section shows how to configure a true multi-party computation where different participants have distinct roles: + +#### 1. Create User Accounts + +Create accounts for each participant in the consortium: + +- **Algorithm Provider** — The neutral data scientist supplying the training code +- **Company 1 Data Provider** — Uploads company_1.csv +- **Company 2 Data Provider** — Uploads company_2.csv +- **Company 3 Data Provider** — Uploads company_3.csv +- **Result Consumer** — The consortium administrator who receives the output + +#### 2. Create a Workspace + +Create a workspace representing the consortium (e.g., "Retail Demand Consortium"). + +#### 3. Create a CVM + +Create a Confidential VM and wait for it to come online. + +#### 4. 
Create the Computation + +Create the computation and set the name and description (e.g., "Q1 Demand Forecast — Multi-Retailer Consortium"). + +Generate sha3-256 checksums for all assets: + +```bash +./build/cocos-cli checksum ../ai/enterprise-analytics/train.py +./build/cocos-cli checksum ../ai/enterprise-analytics/datasets/company_1.csv +./build/cocos-cli checksum ../ai/enterprise-analytics/datasets/company_2.csv +./build/cocos-cli checksum ../ai/enterprise-analytics/datasets/company_3.csv +``` + +#### 5. Add Computation Assets + +Add the algorithm and dataset assets in Prism using the file names and checksums: + +| Asset | File Name | Role | +|---|---|---| +| Algorithm | `train.py` | Algorithm Provider | +| Dataset 1 | `company_1.csv` | Data Provider (Company 1) | +| Dataset 2 | `company_2.csv` | Data Provider (Company 2) | +| Dataset 3 | `company_3.csv` | Data Provider (Company 3) | + +#### 6. Assign Participant Roles + +Use Prism's computation roles to assign each participant: + +- The **Algorithm Provider** can upload the algorithm but cannot see the datasets +- Each **Data Provider** can upload only their own dataset +- The **Result Consumer** can download results but cannot see raw data or the algorithm + +This enforces strict separation of concerns — no single participant has access to all assets. + +#### 7. Upload Public Keys + +Each participant uploads their public key (generated by `cocos-cli`) to enable encrypted uploads and result retrieval. + +### Run the Computation + +1. **Click "Run Computation"** and select an available CVM + +2. **Copy the agent port** and export it: + + ```bash + export AGENT_GRPC_URL=localhost:<port> + ``` + +3. **Algorithm Provider uploads the algorithm:** + + ```bash + ./build/cocos-cli algo ../ai/enterprise-analytics/train.py ./private.pem -a python -r ../ai/enterprise-analytics/requirements.txt + ``` + +4. 
**Each company uploads their dataset independently:** + + Company 1: + ```bash + ./build/cocos-cli data ../ai/enterprise-analytics/datasets/company_1.csv ./private.pem + ``` + + Company 2: + ```bash + ./build/cocos-cli data ../ai/enterprise-analytics/datasets/company_2.csv ./private.pem + ``` + + Company 3: + ```bash + ./build/cocos-cli data ../ai/enterprise-analytics/datasets/company_3.csv ./private.pem + ``` + +5. **Monitor the computation** through the Prism web interface. Events will show algorithm upload, data uploads, computation running, and completion. + +6. **Result Consumer downloads the results:** + + ```bash + ./build/cocos-cli result ./private.pem + ``` + +### Analyze Results + +```bash +cp results.zip ../ai/enterprise-analytics/ +cd ../ai/enterprise-analytics +unzip results.zip -d results +python predict.py +``` + +## Understanding the Security Model + +### Attested TLS (aTLS) — How It Works + +``` + aTLS Handshake +┌──────────┐ ┌──────────────┐ +│ Client │ 1. TLS ClientHello ──────────────────▶ │ TEE Agent │ +│ (Data │ │ (Enclave) │ +│ Provider)│ 2. TLS ServerHello + Attestation ◀── │ │ +│ │ Quote (signed by CPU hardware) │ │ +│ │ │ │ +│ │ 3. Client VERIFIES: │ │ +│ │ ✓ Genuine AMD/Intel hardware │ │ +│ │ ✓ Correct software measurement │ │ +│ │ ✓ Enclave not tampered with │ │ +│ │ │ │ +│ │ 4. Encrypted data upload ─────────────▶ │ [Data is │ +│ │ (only if attestation passed) │ decrypted │ +│ │ │ ONLY inside│ +│ │ │ enclave] │ +└──────────┘ └──────────────┘ +``` + +With aTLS, the TLS handshake includes a **hardware attestation quote** generated by the CPU's secure processor. This quote: + +1. **Proves the hardware is genuine** — The attestation is signed by the CPU manufacturer's root key +2. **Includes a measurement of the software** — A SHA-256 hash of the entire software stack loaded into the enclave +3. 
**Cannot be forged** — Even a compromised hypervisor or cloud administrator cannot generate a valid quote + +### Multi-Party Data Flow + +``` + Company 1 Company 2 Company 3 + │ │ │ + │ aTLS + upload │ aTLS + upload │ aTLS + upload + ▼ ▼ ▼ +┌─────────────────────────────────────────────────┐ +│ TEE Enclave │ +│ │ +│ company_1.csv company_2.csv company_3.csv │ +│ │ │ │ │ +│ └──────────────┼──────────────┘ │ +│ ▼ │ +│ Combined DataFrame │ +│ │ │ +│ Feature Engineering │ +│ │ │ +│ XGBoost Training │ +│ │ │ +│ ┌───────┴───────┐ │ +│ ▼ ▼ │ +│ demand_model benchmark_report │ +│ (no raw data) (aggregated stats) │ +│ │ +│ ⚠ Raw CSVs destroyed after computation │ +└──────────────────────┬───────────────────────────┘ + │ + ▼ aTLS download + Result Consumer +``` + +**Critical security properties:** + +- Each company's data is encrypted in transit (aTLS) and at rest (hardware memory encryption) +- The algorithm cannot exfiltrate raw data — only the computation manifest's approved outputs leave the enclave +- Even the cloud provider and Prism platform operators have zero access to the data inside the TEE +- The benchmark report contains only aggregate metrics (MAE, RMSE, R²) — not individual transaction data + +## Notes + +- **Memory:** 8GB is sufficient for the Online Retail II dataset. Increase for larger enterprise datasets. +- **Runtime:** Training completes in approximately 2–5 minutes depending on hardware. +- **Scaling:** The same architecture supports any number of data providers. Simply add more `-data-paths` entries or additional Data Provider roles in Prism. +- **aTLS on real hardware:** When running on AMD SEV-SNP or Intel TDX servers, set `MANAGER_QEMU_ENABLE_SEV_SNP=true` and `-attested-tls-bool true` to get hardware-backed attestation. In development mode, attestation is simulated. +- **Dataset alternatives:** Any tabular sales/transaction dataset can be substituted. 
The feature engineering in `train.py` expects columns: `Invoice`, `StockCode`, `Description`, `Quantity`, `InvoiceDate`, `Price`, `Customer ID`, `Country`. + diff --git a/enterprise-analytics/predict.py b/enterprise-analytics/predict.py new file mode 100644 index 0000000..89aab12 --- /dev/null +++ b/enterprise-analytics/predict.py @@ -0,0 +1,173 @@ +""" +Enterprise Supply Chain Demand Forecasting - Inference / Result Analysis + +Loads the trained consortium demand model and produces evaluation metrics, +visualizations, and a summary report demonstrating the value of multi-party +collaborative analytics over single-company models. +""" + +import os +import sys + +import matplotlib +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xgboost as xgb + + +def load_results(results_dir: str): + """Load and display all result artifacts.""" + print("=" * 60) + print("ENTERPRISE ANALYTICS - CONSORTIUM RESULTS ANALYSIS") + print("=" * 60) + + # ── Benchmark Report ───────────────────────────────────────────────── + benchmark_path = os.path.join(results_dir, "benchmark_report.csv") + if os.path.exists(benchmark_path): + benchmark = pd.read_csv(benchmark_path) + print("\n📊 BENCHMARK: Consortium vs. 
Individual Company Models") + print("-" * 60) + print(benchmark.to_string(index=False)) + + # Calculate improvement + consortium_r2 = benchmark.loc[ + benchmark["Dataset"].str.contains("Consortium.*all", na=False), "R2" + ].values + solo_r2 = benchmark.loc[ + benchmark["Dataset"].str.contains("solo", na=False), "R2" + ].values + + if len(consortium_r2) > 0 and len(solo_r2) > 0: + avg_solo = solo_r2.mean() + improvement = ((consortium_r2[0] - avg_solo) / max(abs(avg_solo), 0.01)) * 100 + print(f"\n ✅ Consortium R² : {consortium_r2[0]:.4f}") + print(f" 📉 Avg Solo R² : {avg_solo:.4f}") + print(f" 📈 Improvement : {improvement:+.1f}%") + + # Plot benchmark comparison + plot_benchmark(benchmark, results_dir) + else: + print(f" ⚠ Benchmark report not found at {benchmark_path}") + + # ── Feature Importance ─────────────────────────────────────────────── + importance_path = os.path.join(results_dir, "feature_importance.csv") + if os.path.exists(importance_path): + importance = pd.read_csv(importance_path) + print("\n🔑 TOP PREDICTIVE FEATURES") + print("-" * 60) + print(importance.head(10).to_string(index=False)) + plot_feature_importance(importance, results_dir) + else: + print(f" ⚠ Feature importance not found at {importance_path}") + + # ── Demand Forecast ────────────────────────────────────────────────── + forecast_path = os.path.join(results_dir, "monthly_forecast.csv") + if os.path.exists(forecast_path): + forecast = pd.read_csv(forecast_path) + print("\n📈 DEMAND FORECAST SUMMARY (next 3 months)") + print("-" * 60) + summary = ( + forecast.groupby("MonthOffset") + .agg( + AvgDemand=("PredictedDemand", "mean"), + TotalDemand=("PredictedDemand", "sum"), + NumProducts=("StockCode", "nunique"), + ) + .reset_index() + ) + print(summary.to_string(index=False)) + else: + print(f" ⚠ Forecast not found at {forecast_path}") + + # ── Model Info ─────────────────────────────────────────────────────── + model_path = os.path.join(results_dir, "demand_model.ubj") + if 
os.path.exists(model_path): + model = xgb.Booster() + model.load_model(model_path) + print(f"\n🤖 Model loaded successfully: {model_path}") + print(f" Model attributes: {model.attributes()}") + + print("\n" + "=" * 60) + print("ANALYSIS COMPLETE") + print("=" * 60) + + +def plot_benchmark(benchmark: pd.DataFrame, results_dir: str): + """Plot benchmark comparison bar chart.""" + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + fig.suptitle( + "Consortium Model vs. Individual Company Models", + fontsize=14, + fontweight="bold", + ) + + for ax, metric in zip(axes, ["MAE", "RMSE", "R2"]): + data = benchmark[["Dataset", metric]].copy() + colors = [ + "#7c3aed" if "Consortium" in d and "all" in d + else "#14b8a6" if "solo" in d + else "#94a3b8" + for d in data["Dataset"] + ] + short_labels = [ + d.replace("Consortium (all companies)", "Consortium") + .replace("Consortium on ", "C→") + .replace(" (solo model)", " Solo") + for d in data["Dataset"] + ] + ax.barh(short_labels, data[metric], color=colors) + ax.set_xlabel(metric) + ax.set_title(metric) + + plt.tight_layout() + chart_path = os.path.join(results_dir, "benchmark_comparison.png") + plt.savefig(chart_path, dpi=150) + plt.close() + print(f" Saved chart: {chart_path}") + + +def plot_feature_importance(importance: pd.DataFrame, results_dir: str): + """Plot feature importance bar chart.""" + top = importance.head(10) + fig, ax = plt.subplots(figsize=(10, 5)) + ax.barh(top["Feature"][::-1], top["Importance"][::-1], color="#7c3aed") + ax.set_xlabel("Importance (Gain)") + ax.set_title("Top 10 Predictive Features - Consortium Model") + plt.tight_layout() + chart_path = os.path.join(results_dir, "feature_importance.png") + plt.savefig(chart_path, dpi=150) + plt.close() + print(f" Saved chart: {chart_path}") + + +def main(): + datasets_dir = "datasets" + results_dir = "results" + + if not os.path.isdir(results_dir): + print(f"Results directory {results_dir} not found") + return + + # Check for model and reports + model_f = 
None + for f in os.listdir(results_dir): + if f.endswith(".ubj"): + model_f = f + + if model_f is None: + # Check datasets dir as fallback + if os.path.isdir(datasets_dir): + for f in os.listdir(datasets_dir): + if f.endswith(".ubj"): + model_f = f + results_dir = datasets_dir + + load_results(results_dir) + + +if __name__ == "__main__": + main() + diff --git a/enterprise-analytics/tools/prepare_datasets.py b/enterprise-analytics/tools/prepare_datasets.py new file mode 100644 index 0000000..d5d7519 --- /dev/null +++ b/enterprise-analytics/tools/prepare_datasets.py @@ -0,0 +1,122 @@ +""" +Enterprise Supply Chain Demand Forecasting - Data Preparation + +Splits the UCI Online Retail II dataset into 3 separate company datasets +simulating a multi-party enterprise analytics consortium where retailers +collaborate on demand forecasting without exposing proprietary sales data. + +Dataset: UCI Online Retail II +Source: https://www.kaggle.com/datasets/mashlyn/online-retail-ii-uci + +Each company dataset contains: +- Transaction history for a disjoint set of customers +- Invoice dates, product codes, quantities, and unit prices +- Country-level geographic distribution +""" + +import argparse +import os +import random +import zipfile + +import pandas as pd + + +def load_dataset(zip_path: str) -> pd.DataFrame: + """Load and clean the Online Retail II dataset from a zip file.""" + with zipfile.ZipFile(zip_path, "r") as z: + xlsx_files = [f for f in z.namelist() if f.endswith(".xlsx")] + if not xlsx_files: + raise FileNotFoundError("No .xlsx file found in the zip archive") + + with z.open(xlsx_files[0]) as f: + df = pd.read_excel(f, engine="openpyxl") + + print(f"Loaded {len(df)} rows from {xlsx_files[0]}") + + # Basic cleaning + df = df.dropna(subset=["Customer ID", "Description"]) + df = df[df["Quantity"] > 0] + df = df[df["Price"] > 0] + df["Customer ID"] = df["Customer ID"].astype(int) + df["Revenue"] = df["Quantity"] * df["Price"] + df["InvoiceDate"] = 
pd.to_datetime(df["InvoiceDate"]) + df["YearMonth"] = df["InvoiceDate"].dt.to_period("M").astype(str) + + print(f"After cleaning: {len(df)} rows, {df['Customer ID'].nunique()} customers") + return df + + +def split_by_company(df: pd.DataFrame, n_companies: int = 3, seed: int = 42): + """Split dataset into n disjoint company datasets by customer ID.""" + random.seed(seed) + + customers = list(df["Customer ID"].unique()) + random.shuffle(customers) + + chunk_size = len(customers) // n_companies + company_customers = [] + for i in range(n_companies): + start = i * chunk_size + end = start + chunk_size if i < n_companies - 1 else len(customers) + company_customers.append(set(customers[start:end])) + + company_dfs = [] + for i, cust_set in enumerate(company_customers): + company_df = df[df["Customer ID"].isin(cust_set)].copy() + company_dfs.append(company_df) + print( + f"Company {i + 1}: {len(company_df)} transactions, " + f"{len(cust_set)} customers, " + f"{company_df['Country'].nunique()} countries" + ) + + return company_dfs + + +def save_datasets(company_dfs: list, output_dir: str): + """Save each company dataset as a CSV file.""" + os.makedirs(output_dir, exist_ok=True) + + for i, df in enumerate(company_dfs): + filename = f"company_{i + 1}.csv" + filepath = os.path.join(output_dir, filename) + df.to_csv(filepath, index=False) + print(f"Saved {filepath} ({len(df)} rows)") + + +def main(): + parser = argparse.ArgumentParser( + description="Prepare enterprise analytics datasets from UCI Online Retail II" + ) + parser.add_argument( + "zipfile", + type=str, + help="Path to the online-retail-ii-uci.zip file", + ) + parser.add_argument( + "-d", + "--destination", + type=str, + default="datasets", + help="Output directory (default: datasets)", + ) + parser.add_argument( + "-n", + "--num-companies", + type=int, + default=3, + help="Number of companies to split data into (default: 3)", + ) + args = parser.parse_args() + + df = load_dataset(args.zipfile) + company_dfs = 
split_by_company(df, n_companies=args.num_companies) + save_datasets(company_dfs, args.destination) + + print(f"\nDataset preparation complete. {args.num_companies} company datasets saved to '{args.destination}/'") + + +if __name__ == "__main__": + main() + diff --git a/enterprise-analytics/train.py b/enterprise-analytics/train.py new file mode 100644 index 0000000..2b6e29b --- /dev/null +++ b/enterprise-analytics/train.py @@ -0,0 +1,310 @@ +""" +Enterprise Supply Chain Demand Forecasting - Training Algorithm + +This algorithm runs inside a Trusted Execution Environment (TEE) on the +Prism AI / Cocos platform. It receives proprietary sales datasets from +multiple competing retailers and trains a unified demand forecasting model +without any party seeing another's raw data. + +Scenario: + Three retail companies contribute their transaction histories to jointly + train an XGBoost model that predicts monthly product demand. The combined + model outperforms any single-company model because it captures broader + market signals, seasonal trends, and cross-geographic demand patterns. + +Inputs (uploaded as datasets/): + - company_1.csv, company_2.csv, company_3.csv + +Outputs (saved to results/): + - demand_model.ubj : Trained XGBoost model + - benchmark_report.csv : Per-company vs. 
consortium accuracy + - feature_importance.csv : Top predictive features + - monthly_forecast.csv : 3-month forward demand forecast +""" + +import os + +os.environ["OPENBLAS_L2_SIZE"] = "1024" + +import warnings + +import numpy as np +import pandas as pd +import xgboost as xgb +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder + +warnings.filterwarnings("ignore") + +DATASETS_DIR = "datasets" +RESULTS_DIR = "results" + + +# ── Feature Engineering ────────────────────────────────────────────────────── + +def build_features(df: pd.DataFrame) -> pd.DataFrame: + """Create time-series demand features from raw transaction data.""" + df = df.copy() + df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"]) + df["YearMonth"] = df["InvoiceDate"].dt.to_period("M") + df["Month"] = df["InvoiceDate"].dt.month + df["DayOfWeek"] = df["InvoiceDate"].dt.dayofweek + df["WeekOfYear"] = df["InvoiceDate"].dt.isocalendar().week.astype(int) + + # Aggregate to monthly product-level demand + monthly = ( + df.groupby(["StockCode", "YearMonth", "Country", "Month", "WeekOfYear"]) + .agg( + TotalQuantity=("Quantity", "sum"), + TotalRevenue=("Revenue", "sum"), + AvgPrice=("Price", "mean"), + NumTransactions=("Invoice", "nunique"), + NumCustomers=("Customer ID", "nunique"), + ) + .reset_index() + ) + + # Encode categoricals + le_stock = LabelEncoder() + le_country = LabelEncoder() + monthly["StockCode_enc"] = le_stock.fit_transform(monthly["StockCode"].astype(str)) + monthly["Country_enc"] = le_country.fit_transform(monthly["Country"].astype(str)) + + # Sort for lag features + monthly["YearMonth_str"] = monthly["YearMonth"].astype(str) + monthly = monthly.sort_values(["StockCode", "YearMonth_str"]) + + # Lag features (previous month demand) + monthly["Lag1_Quantity"] = monthly.groupby("StockCode")["TotalQuantity"].shift(1) + monthly["Lag2_Quantity"] = 
monthly.groupby("StockCode")["TotalQuantity"].shift(2) + monthly["Lag1_Revenue"] = monthly.groupby("StockCode")["TotalRevenue"].shift(1) + + # Rolling averages + monthly["Rolling3_Quantity"] = ( + monthly.groupby("StockCode")["TotalQuantity"] + .transform(lambda x: x.rolling(3, min_periods=1).mean()) + ) + + monthly = monthly.dropna(subset=["Lag1_Quantity"]) + + return monthly + + +FEATURE_COLS = [ + "StockCode_enc", + "Country_enc", + "Month", + "WeekOfYear", + "AvgPrice", + "NumTransactions", + "NumCustomers", + "Lag1_Quantity", + "Lag2_Quantity", + "Lag1_Revenue", + "Rolling3_Quantity", +] +TARGET_COL = "TotalQuantity" + + +# ── Training ───────────────────────────────────────────────────────────────── + +def train_model(X_train, y_train, X_val, y_val): + """Train XGBoost demand forecasting model.""" + dtrain = xgb.DMatrix(X_train, label=y_train) + dval = xgb.DMatrix(X_val, label=y_val) + + params = { + "objective": "reg:squarederror", + "eval_metric": "rmse", + "eta": 0.1, + "max_depth": 6, + "subsample": 0.8, + "colsample_bytree": 0.8, + "min_child_weight": 5, + "seed": 42, + } + + model = xgb.train( + params, + dtrain, + num_boost_round=300, + evals=[(dval, "validation")], + early_stopping_rounds=15, + verbose_eval=50, + ) + + return model + + +def evaluate_model(model, X, y, label=""): + """Evaluate model and return metrics dict.""" + dmatrix = xgb.DMatrix(X) + preds = model.predict(dmatrix) + mae = mean_absolute_error(y, preds) + rmse = np.sqrt(mean_squared_error(y, preds)) + r2 = r2_score(y, preds) + print(f" [{label}] MAE: {mae:.2f}, RMSE: {rmse:.2f}, R²: {r2:.4f}") + return {"Dataset": label, "MAE": mae, "RMSE": rmse, "R2": r2} + + +# ── Main ───────────────────────────────────────────────────────────────────── + +def main(): + if not os.path.isdir(DATASETS_DIR): + print(f"Dataset directory {DATASETS_DIR} not found") + return + + os.makedirs(RESULTS_DIR, exist_ok=True) + + # ── Load all company datasets ──────────────────────────────────────── + 
company_files = sorted( + [f for f in os.listdir(DATASETS_DIR) if f.endswith(".csv")] + ) + + if not company_files: + print("No CSV datasets found in datasets/") + return + + print(f"Found {len(company_files)} company datasets: {company_files}") + print("=" * 60) + + company_dfs = {} + for f in company_files: + name = os.path.splitext(f)[0] + df = pd.read_csv(os.path.join(DATASETS_DIR, f)) + company_dfs[name] = df + print(f" {name}: {len(df)} transactions") + + # ── Build features per company and combined ────────────────────────── + print("\nBuilding features...") + company_features = {} + all_features = [] + + for name, df in company_dfs.items(): + features = build_features(df) + company_features[name] = features + all_features.append(features) + print(f" {name}: {len(features)} monthly demand records") + + combined = pd.concat(all_features, ignore_index=True) + print(f" Combined consortium: {len(combined)} monthly demand records") + + # ── Train consortium model (all companies together) ────────────────── + print("\n" + "=" * 60) + print("TRAINING CONSORTIUM MODEL (all companies)") + print("=" * 60) + + X = combined[FEATURE_COLS].values + y = combined[TARGET_COL].values + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + X_train, X_val, y_train, y_val = train_test_split( + X_train, y_train, test_size=0.15, random_state=42 + ) + + consortium_model = train_model(X_train, y_train, X_val, y_val) + + # ── Evaluate consortium model ──────────────────────────────────────── + print("\nConsortium model evaluation:") + benchmark_rows = [] + consortium_metrics = evaluate_model( + consortium_model, X_test, y_test, "Consortium (all companies)" + ) + benchmark_rows.append(consortium_metrics) + + # Evaluate on each company's data separately + for name, features in company_features.items(): + Xc = features[FEATURE_COLS].values + yc = features[TARGET_COL].values + metrics = evaluate_model(consortium_model, Xc, yc, f"Consortium on 
{name}") + benchmark_rows.append(metrics) + + # ── Train individual company models for comparison ─────────────────── + print("\n" + "=" * 60) + print("TRAINING INDIVIDUAL COMPANY MODELS (for benchmark)") + print("=" * 60) + + for name, features in company_features.items(): + print(f"\n Training model for {name}...") + Xc = features[FEATURE_COLS].values + yc = features[TARGET_COL].values + + if len(Xc) < 50: + print(f" Skipping {name}: insufficient data ({len(Xc)} records)") + continue + + Xc_train, Xc_test, yc_train, yc_test = train_test_split( + Xc, yc, test_size=0.2, random_state=42 + ) + Xc_train, Xc_val, yc_train, yc_val = train_test_split( + Xc_train, yc_train, test_size=0.15, random_state=42 + ) + + individual_model = train_model(Xc_train, yc_train, Xc_val, yc_val) + metrics = evaluate_model( + individual_model, Xc_test, yc_test, f"{name} (solo model)" + ) + benchmark_rows.append(metrics) + + # ── Save results ───────────────────────────────────────────────────── + print("\n" + "=" * 60) + print("SAVING RESULTS") + print("=" * 60) + + # 1. Model + model_path = os.path.join(RESULTS_DIR, "demand_model.ubj") + consortium_model.save_model(model_path) + print(f" Saved model: {model_path}") + + # 2. Benchmark report + benchmark_df = pd.DataFrame(benchmark_rows) + benchmark_path = os.path.join(RESULTS_DIR, "benchmark_report.csv") + benchmark_df.to_csv(benchmark_path, index=False) + print(f" Saved benchmark: {benchmark_path}") + print("\n BENCHMARK RESULTS:") + print(benchmark_df.to_string(index=False)) + + # 3. 
Feature importance + importance = consortium_model.get_score(importance_type="gain") + importance_df = pd.DataFrame( + [{"Feature": FEATURE_COLS[int(k[1:])] if k.startswith("f") else k, + "Importance": v} + for k, v in importance.items()] + ).sort_values("Importance", ascending=False) + importance_path = os.path.join(RESULTS_DIR, "feature_importance.csv") + importance_df.to_csv(importance_path, index=False) + print(f"\n Saved feature importance: {importance_path}") + print(importance_df.to_string(index=False)) + + # 4. Forward forecast (next 3 months based on last known data) + last_month_data = combined.sort_values("YearMonth_str").groupby("StockCode_enc").tail(1) + forecast_rows = [] + for month_offset in range(1, 4): + forecast_input = last_month_data[FEATURE_COLS].copy() + forecast_input["Month"] = (forecast_input["Month"] + month_offset - 1) % 12 + 1 + dpred = xgb.DMatrix(forecast_input.values) + preds = consortium_model.predict(dpred) + for idx, (_, row) in enumerate(last_month_data.iterrows()): + forecast_rows.append({ + "StockCode": row["StockCode"], + "Country": row["Country"], + "MonthOffset": month_offset, + "PredictedDemand": max(0, preds[idx]), + }) + + forecast_df = pd.DataFrame(forecast_rows) + forecast_path = os.path.join(RESULTS_DIR, "monthly_forecast.csv") + forecast_df.to_csv(forecast_path, index=False) + print(f"\n Saved forecast: {forecast_path}") + + print("\n" + "=" * 60) + print("COMPUTATION COMPLETE") + print("=" * 60) + + +if __name__ == "__main__": + main() + From 8de2b680eca3dce459485dc8734e4577cf26131c Mon Sep 17 00:00:00 2001 From: Jilks Smith Date: Fri, 27 Mar 2026 12:31:16 +0300 Subject: [PATCH 2/3] Fix enterprise analytics tutorial --- .../tools/prepare_datasets.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/enterprise-analytics/tools/prepare_datasets.py b/enterprise-analytics/tools/prepare_datasets.py index d5d7519..cd33c30 100644 --- a/enterprise-analytics/tools/prepare_datasets.py +++ 
b/enterprise-analytics/tools/prepare_datasets.py @@ -26,13 +26,20 @@ def load_dataset(zip_path: str) -> pd.DataFrame: """Load and clean the Online Retail II dataset from a zip file.""" with zipfile.ZipFile(zip_path, "r") as z: xlsx_files = [f for f in z.namelist() if f.endswith(".xlsx")] - if not xlsx_files: - raise FileNotFoundError("No .xlsx file found in the zip archive") - - with z.open(xlsx_files[0]) as f: - df = pd.read_excel(f, engine="openpyxl") - - print(f"Loaded {len(df)} rows from {xlsx_files[0]}") + csv_files = [f for f in z.namelist() if f.endswith(".csv")] + + if xlsx_files: + with z.open(xlsx_files[0]) as f: + df = pd.read_excel(f, engine="openpyxl") + data_file = xlsx_files[0] + elif csv_files: + with z.open(csv_files[0]) as f: + df = pd.read_csv(f, encoding="utf-8") + data_file = csv_files[0] + else: + raise FileNotFoundError("No .xlsx or .csv file found in the zip archive") + + print(f"Loaded {len(df)} rows from {data_file}") # Basic cleaning df = df.dropna(subset=["Customer ID", "Description"]) From 4d47b031459a1bb74f6a32b11a03ae6eaa6ea0d6 Mon Sep 17 00:00:00 2001 From: Jilks Smith Date: Fri, 27 Mar 2026 13:37:55 +0300 Subject: [PATCH 3/3] Fix enterprise tutorial --- enterprise-analytics/predict.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/enterprise-analytics/predict.py b/enterprise-analytics/predict.py index 89aab12..01403a2 100644 --- a/enterprise-analytics/predict.py +++ b/enterprise-analytics/predict.py @@ -56,18 +56,18 @@ def load_results(results_dir: str): importance_path = os.path.join(results_dir, "feature_importance.csv") if os.path.exists(importance_path): importance = pd.read_csv(importance_path) - print("\n🔑 TOP PREDICTIVE FEATURES") + print("\n TOP PREDICTIVE FEATURES") print("-" * 60) print(importance.head(10).to_string(index=False)) plot_feature_importance(importance, results_dir) else: - print(f" ⚠ Feature importance not found at {importance_path}") + print(f" Feature importance not found at 
{importance_path}") # ── Demand Forecast ────────────────────────────────────────────────── forecast_path = os.path.join(results_dir, "monthly_forecast.csv") if os.path.exists(forecast_path): forecast = pd.read_csv(forecast_path) - print("\n📈 DEMAND FORECAST SUMMARY (next 3 months)") + print("\n DEMAND FORECAST SUMMARY (next 3 months)") print("-" * 60) summary = ( forecast.groupby("MonthOffset") @@ -80,14 +80,14 @@ def load_results(results_dir: str): ) print(summary.to_string(index=False)) else: - print(f" ⚠ Forecast not found at {forecast_path}") + print(f" Forecast not found at {forecast_path}") # ── Model Info ─────────────────────────────────────────────────────── model_path = os.path.join(results_dir, "demand_model.ubj") if os.path.exists(model_path): model = xgb.Booster() model.load_model(model_path) - print(f"\n🤖 Model loaded successfully: {model_path}") + print(f"\n Model loaded successfully: {model_path}") print(f" Model attributes: {model.attributes()}") print("\n" + "=" * 60)